docs/source/internal/generation_utils.mdx (15 additions, 3 deletions)
@@ -154,18 +154,30 @@ generation.
[[autodoc]] TFLogitsProcessorList
- __call__

[[autodoc]] TFLogitsWarper
- __call__

[[autodoc]] TFTemperatureLogitsWarper
- __call__

[[autodoc]] TFTopPLogitsWarper
- __call__

[[autodoc]] TFTopKLogitsWarper
- __call__

[[autodoc]] TFMinLengthLogitsProcessor
- __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
- __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
- __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__

[[autodoc]] FlaxLogitsProcessor
- __call__

src/transformers/__init__.py (8 additions)
@@ -1646,10 +1646,14 @@
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
@@ -3688,10 +3692,14 @@
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
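
# Illustrative sketch, not part of this diff: with the `_import_structure`
# entries above and these eager imports in place, the new TF warpers become
# importable from the top-level package, for example:
#
#     from transformers import (
#         TFLogitsProcessorList,
#         TFTemperatureLogitsWarper,
#         TFTopKLogitsWarper,
#         TFTopPLogitsWarper,
#     )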
src/transformers/generation_flax_logits_process.py (3 additions, 3 deletions)
@@ -94,7 +94,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **

class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).

Args:
temperature (`float`):
@@ -114,7 +114,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->

class FlaxTopPLogitsWarper(FlaxLogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
[`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.

Args:
top_p (`float`):
@@ -155,7 +155,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->

class FlaxTopKLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
[`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

Args:
top_k (`int`):
src/transformers/generation_flax_utils.py (1 addition, 1 deletion)
@@ -326,7 +326,7 @@ def generate(
raise NotImplementedError("Beam sampling is currently not implemented.")

def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> FlaxLogitsProcessorList:
"""
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
src/transformers/generation_tf_logits_process.py (114 additions)
@@ -58,6 +58,17 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
)


class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)


class TFLogitsProcessorList(list):
"""
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
@@ -81,6 +92,109 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
return scores


class TFTemperatureLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).

Args:
temperature (`float`):
The value used to module the logits distribution.
"""

def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

self.temperature = temperature

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
scores = scores / self.temperature
return scores
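
# Illustrative usage sketch, not part of this diff: dividing the logits by the
# temperature sharpens the softmax distribution for values < 1.0 and flattens
# it for values > 1.0. The tensors below are made-up toy values.
#
#     import tensorflow as tf
#
#     scores = tf.constant([[1.0, 2.0, 3.0]])
#     input_ids = tf.constant([[0]])  # not used by this warper
#     sharp = TFTemperatureLogitsWarper(temperature=0.5)(input_ids, scores)
#     flat = TFTemperatureLogitsWarper(temperature=2.0)(input_ids, scores)
#     tf.nn.softmax(sharp)  # more peaked than tf.nn.softmax(scores)
#     tf.nn.softmax(flat)   # closer to uniform than tf.nn.softmax(scores)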


class TFTopKLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""

def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
# Boolean mask containing all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
next_scores = tf.where(indices_to_remove, self.filter_value, scores)
return next_scores
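
# Illustrative usage sketch, not part of this diff: only the `top_k` largest
# logits survive; all others are set to `filter_value` so they receive zero
# probability after softmax. The tensors below are made-up toy values.
#
#     import tensorflow as tf
#
#     scores = tf.constant([[0.1, 0.9, 0.2, 0.8]])
#     input_ids = tf.constant([[0]])  # not used by this warper
#     TFTopKLogitsWarper(top_k=2)(input_ids, scores)
#     # -> [[-inf, 0.9, -inf, 0.8]]: only the two largest logits are kept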


class TFTopPLogitsWarper(TFLogitsWarper):
"""
[`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.

Args:
top_p (`float`):
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""

def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

mask_scores = tf.fill(scores.shape, self.filter_value)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
score_mask = cumulative_probs < self.top_p

# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)

# Ensure min tokens to keep
score_mask = tf.concat(
(
tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
score_mask[:, self.min_tokens_to_keep :],
),
axis=-1,
)

# Mask the values that do not fit the criteria
topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)

# Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
# to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
# can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)

return next_scores
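
# Illustrative usage sketch, not part of this diff: in practice the warpers
# above are chained in a `TFLogitsProcessorList` and applied to the next-token
# logits before multinomial sampling. Shapes and hyperparameters below are
# made up for demonstration.
#
#     import tensorflow as tf
#
#     warpers = TFLogitsProcessorList(
#         [
#             TFTemperatureLogitsWarper(temperature=0.7),
#             TFTopKLogitsWarper(top_k=50),
#             TFTopPLogitsWarper(top_p=0.9),
#         ]
#     )
#     input_ids = tf.constant([[101, 7592]])            # (batch_size, cur_len)
#     next_token_logits = tf.random.normal((1, 32000))  # (batch_size, vocab_size)
#     warped_logits = warpers(input_ids, next_token_logits)
#     next_token = tf.random.categorical(warped_logits, num_samples=1)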


class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.