Add InfoLM (#915)
* First full pass working
* Add InfoLM module metric and fix some minor issues
* Add module metric InfoLM
* Partially apply some suggestions from code review
* Fix an indentation issue in a docstring
* Uncomment test parameters
* Fix a minor bug, specify docstring + fix max_len for the test
* Add some missing parts
* Add IDF support
* Add 'backsorting' of scores
* Update reference test results after finding the bug
* Fix functional metric tests and some minor things
* Update class test
* Skip checking whether a metric is scriptable as transformers are usually not scriptable
* Update test results
* Use a different link for the Fisher-Rao distance
* Drop use_cache kwarg
* Use dim_zero_cat instead of torch.cat in class metric
* Set num_threads=1 when dist_sync_on_step=True

Daemonic processes are not allowed to have children. As a consequence,
the application fails when num_threads>1 is used together with
dist_sync_on_step=True. We therefore enforce num_threads=1 automatically
(see the sketch below).
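
A minimal sketch of the underlying multiprocessing limitation (illustrative only, not code from this PR):

    import multiprocessing as mp

    def worker() -> None:
        # Starting a child from a daemonic process raises:
        # AssertionError: daemonic processes are not allowed to have children
        mp.Process(target=print, args=("hello",)).start()

    if __name__ == "__main__":
        p = mp.Process(target=worker, daemon=True)
        p.start()
        p.join()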

* Drop enforcing num_threads=1 and add a link to the repo for generating test results
* Replace _TRANSFORMERS_AUTO_AVAILABLE with _TRANSFORMERS_AVAILABLE everywhere
* Fix device placement and set num_threads=0 as default (SkafteNicki's review)
* Make testing conditional on dependency and connection
* Skip tests if transformers are not available
* Mark tests as skipped if there is any connection issue with the HF Hub

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
3 people authored Jul 12, 2022
1 parent c7b7c46 commit 4ebb4a2
Showing 13 changed files with 1,387 additions and 235 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added


- Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915))


- Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))


6 changes: 6 additions & 0 deletions docs/source/links.rst
@@ -86,3 +86,9 @@
.. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIP: https://infoscience.epfl.ch/record/82766
.. _InfoLM: https://arxiv.org/pdf/2112.01589.pdf
.. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _beta divergence: https://www.sciencedirect.com/science/article/pii/S0047259X08000456
.. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf
.. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
21 changes: 21 additions & 0 deletions docs/source/text/infolm.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: InfoLM
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
:tags: Text

.. include:: ../links.rst

######
InfoLM
######

Module Interface
________________

.. autoclass:: torchmetrics.text.infolm.InfoLM
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.text.infolm.infolm
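
For orientation, a minimal usage sketch of the functional interface added in this PR; the checkpoint name and keyword arguments below are assumptions based on the merged API, and running it requires `transformers` plus access to the HF Hub:

    from torchmetrics.functional.text.infolm import infolm

    preds = ["he read the book because he was interested in world history"]
    target = ["he was interested in world history because he read the book"]

    # A small BERT checkpoint keeps the download light-weight; any
    # masked-LM checkpoint from `transformers` should work here.
    score = infolm(
        preds,
        target,
        model_name_or_path="google/bert_uncased_L-2_H-128_A-2",
        information_measure="kl_divergence",
        idf=False,
    )
    print(score)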
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/__init__.py
@@ -86,10 +86,11 @@
from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.functional.text.wil import word_information_lost
from torchmetrics.functional.text.wip import word_information_preserved
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401
from torchmetrics.functional.text.infolm import infolm # noqa: F401

__all__ = [
"accuracy",
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/text/__init__.py
@@ -24,10 +24,11 @@
from torchmetrics.functional.text.wer import word_error_rate # noqa: F401
from torchmetrics.functional.text.wil import word_information_lost # noqa: F401
from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401
from torchmetrics.functional.text.infolm import infolm # noqa: F401

if _NLTK_AVAILABLE:
from torchmetrics.functional.text.rouge import rouge_score # noqa: F401
238 changes: 17 additions & 221 deletions src/torchmetrics/functional/text/bert.py
@@ -12,240 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import math
import urllib
from collections import Counter, defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
from transformers.models.auto import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader

from torchmetrics.functional.text.helper_embedding_metric import (
TextDataset,
TokenizedDataset,
_check_shape_of_model_output,
_get_progress_bar,
_input_data_collator,
_output_data_collator,
_process_attention_mask_for_special_tokens,
)
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from transformers import AutoModel, AutoTokenizer
else:
__doctest_skip__ = ["bert_score"]

if _TQDM_AVAILABLE:
import tqdm


# Default model recommended in the original implementation.
_DEFAULT_MODEL = "roberta-large"


def _preprocess_text(
text: List[str],
tokenizer: Any,
max_length: int = 512,
truncation: bool = True,
sort_according_length: bool = True,
own_tokenizer: bool = False,
) -> Dict[str, Tensor]:
"""Default text pre-processing function using `transformers` `AutoTokenizer` instance.
Args:
text: An iterable of sentences.
tokenizer: Either ``AutoTokenizer`` instance from ``transformers`` package, or a user's own tokenizer.
max_length: A maximum sequence length.
truncation:
An indication of whether tokenized sequences should be padded only to the length of the longest sequence.
sort_according_length:
An indication of whether tokenized sequences should be sorted from shortest to longest. This is appropriate
to do for leveraging dynamic padding during embedding calculation and thereby to hasten inference.
own_tokenizer: An indication of whether a non-default user's own tokenizer is used.
Return:
A dictionary of tokenized sentences including ``input_ids`` and ``attention_mask``.
Raises:
BaseException:
If a tokenization with a user's own tokenizer is not successful.
"""
if not own_tokenizer:
tokenized_data = tokenizer(
text, padding="max_length", max_length=max_length, truncation=truncation, return_tensors="pt"
)
else:
try:
tokenized_data = tokenizer(text, max_length)
except BaseException as ex:
raise BaseException(f"Tokenization was not successful: {ex}")

input_ids, attention_mask = (
_sort_data_according_length(tokenized_data["input_ids"], tokenized_data["attention_mask"])
if sort_according_length
else (tokenized_data["input_ids"], tokenized_data["attention_mask"])
)
return {"input_ids": input_ids, "attention_mask": attention_mask}


def _process_attention_mask_for_special_tokens(attention_mask: Tensor) -> Tensor:
"""Process attention mask to be zero for special [CLS] and [SEP] tokens as they're not included in a
calculation for BERT score.
Args:
attention_mask: An attention mask to be returned, for example, by a ``transformers`` tokenizer.
Return:
A processed attention mask.
"""
# Make attention_mask zero for [CLS] token
attention_mask[:, 0] = 0
# Make attention_mask zero for [SEP] token
sep_token_position = (attention_mask - 0.1).cumsum(-1).argmax(-1)
attention_mask[torch.arange(attention_mask.size(0)).long(), sep_token_position] = 0
return attention_mask
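
# --- Illustrative aside, not part of this module: the SEP-locating line above
# is terse. Subtracting 0.1 makes padded zeros negative, so the cumulative sum
# peaks at the last 1 of each row, which is the [SEP] position. With
# hypothetical values:
#
#   attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
#                                  [1, 1, 1, 1, 1, 1]])
#   (attention_mask - 0.1).cumsum(-1).argmax(-1)  # -> tensor([3, 5])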


def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tuple[Tensor, Tensor]:
"""Sort tokenized sentence from the shortest to the longest one."""
sorted_indices = attention_mask.sum(1).argsort()
input_ids = input_ids[sorted_indices]
attention_mask = attention_mask[sorted_indices]
return input_ids, attention_mask


def _input_data_collator(
batch: Dict[str, Tensor], device: Optional[Union[str, torch.device]] = None
) -> Dict[str, Tensor]:
"""Helper function that trims model inputs to the longest sequence within the batch and put the input on the
proper device."""
max_len = int(batch["attention_mask"].sum(1).max().item())
input_ids = batch["input_ids"][:, :max_len].to(device)
attention_mask = batch["attention_mask"][:, :max_len].to(device)
batch.update({"input_ids": input_ids, "attention_mask": attention_mask})
return batch


def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> Tuple[Tensor, Tensor]:
"""Helper function that pads the model output and attention mask to the target length."""
zeros_shape = list(model_output.shape)
zeros_shape[2] = target_len - zeros_shape[2]
model_output = torch.cat(
[model_output, torch.zeros(zeros_shape, dtype=model_output.dtype).to(model_output.device)], dim=2
)
zeros = torch.zeros(zeros_shape[0], zeros_shape[2], dtype=attention_mask.dtype).to(attention_mask.device)
attention_mask = torch.cat([attention_mask, zeros], dim=1)
return model_output, attention_mask


class TextDataset(Dataset):
"""PyTorch dataset class for storing tokenized sentences and other properties used for BERT score
calculation."""

def __init__(
self,
text: List[str],
tokenizer: Any,
max_length: int = 512,
preprocess_text_fn: Callable[[List[str], Any, int], Dict[str, Tensor]] = _preprocess_text,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
) -> None:
"""
Args:
text: An iterable of sentences.
tokenizer: ``AutoTokenizer`` instance from ``transformers`` package.
max_length: A maximum sequence length.
preprocess_text_fn: A function used for processing the input sentences.
idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
"""
self.text = preprocess_text_fn(text, tokenizer, max_length)
self.max_length = self.text["input_ids"].shape[1]
self.num_sentences = len(text)
self.idf = idf
self.tokens_idf = {}
if idf:
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()

def __getitem__(self, idx: int) -> Dict[str, Tensor]:
input_ids = self.text["input_ids"][idx, :]
attention_mask = self.text["attention_mask"][idx, :]
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
if self.idf:
input_ids_idf = torch.tensor([self.tokens_idf[input_idx] for input_idx in input_ids.tolist()])
inputs_dict["input_ids_idf"] = input_ids_idf
return inputs_dict

def __len__(self) -> int:
return self.num_sentences

def _get_tokens_idf(self) -> Dict[int, float]:
"""Calculate token inverse document frequencies.
Return:
A python dictionary containing inverse document frequencies for token ids.
"""
token_counter: Counter = Counter()
for tokens in map(self._set_of_tokens, self.text["input_ids"]):
token_counter.update(tokens)

tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value)
tokens_idf.update(
{idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items()}
)
return tokens_idf

def _get_tokens_idf_default_value(self) -> float:
"""Helper function that ensures ``defaultdict`` to be pickled."""
return math.log((self.num_sentences + 1) / 1)

@staticmethod
def _set_of_tokens(input_ids: Tensor) -> Set:
"""Return set of tokens from the ``input_ids``."""
return set(input_ids.tolist())
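
# --- Illustrative aside, not part of this module: a worked check of the
# smoothed IDF formula in `_get_tokens_idf` above, assuming 4 sentences:
#
#   math.log((4 + 1) / (4 + 1))  # token seen in every sentence -> 0.0
#   math.log((4 + 1) / (1 + 1))  # token seen in one sentence   -> ~0.916
#
# Tokens never seen fall back to math.log((num_sentences + 1) / 1) via the
# defaultdict default defined in `_get_tokens_idf_default_value`.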


class TokenizedDataset(TextDataset):
"""The child class of ``TextDataset`` class used with already tokenized data."""

def __init__(
self,
input_ids: Tensor,
attention_mask: Tensor,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
) -> None:
"""
Args:
input_ids: Input ids.
attention_mask: Attention mask.
idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
"""
self.text = dict(zip(["input_ids", "attention_mask"], _sort_data_according_length(input_ids, attention_mask)))
self.text = _input_data_collator(self.text)
self.num_sentences = len(self.text["input_ids"])
self.max_length = self.text["input_ids"].shape[1]
self.idf = idf
self.tokens_idf = {}
if idf:
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()


def _get_progress_bar(dataloader: DataLoader, verbose: bool = False) -> Union[DataLoader, "tqdm.auto.tqdm"]:
"""Helper function returning either the dataloader itself when ``verbose = False``, or it wraps the dataloader with
``tqdm.auto.tqdm``, when ``verbose = True`` to display a progress bar during the embeddings calculation."""
return tqdm.auto.tqdm(dataloader) if verbose else dataloader


def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None:
"""Check if the shape of the user's own model output."""
bs, seq_len = input_ids.shape[:2]
invalid_out_shape = len(output.shape) != 3 or output.shape[0] != bs or output.shape[1] != seq_len
if invalid_out_shape:
raise ValueError(
"The model output must be `torch.Tensor` of a shape `[batch_size, seq_len, model_dim]` "
f"i.e. [{bs}, {seq_len}. , `model_dim`], but got {output.shape}."
)


def _get_embeddings_and_idf_scale(
dataloader: DataLoader,
target_len: int,
@@ -537,7 +333,7 @@ def bert_score(
)

if model is None:
if not _TRANSFORMERS_AUTO_AVAILABLE:
if not _TRANSFORMERS_AVAILABLE:
raise ModuleNotFoundError(
"`bert_score` metric with default models requires `transformers` package be installed."
" Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."