Merged: changes from all commits (21 commits)
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):

use_bfloat16 (`bool`, *optional*, defaults to `False`):
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
tf_legacy_loss (`bool`, *optional*, defaults to `False`):

Contributor:
Are we sure we want to set the default to `False`? This is breaking, no? Also, it's a somewhat hard-to-discover silent error in this case, no?

Member Author (@Rocketknight1, Jul 4, 2022):
It's only very slightly breaking - anyone using Keras or a custom model will not notice any change. The existing losses return very strange shapes like vectors of shape (num_unmasked_tokens,) that vary in length each iteration, with no mapping from there back to the original tokens. I doubt anyone is using them directly, without computing tf.reduce_mean() on them.

Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
v5.
"""
model_type: str = ""
is_composition: bool = False
@@ -260,6 +264,7 @@ def __init__(self, **kwargs):
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
"tie_word_embeddings", True
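For readers skimming the diff, a minimal sketch of how a user would opt back into the legacy behaviour via the new flag (the checkpoint name is only an illustrative assumption):

    from transformers import TFAutoModelForTokenClassification

    # Config kwargs passed to from_pretrained are forwarded to the config, so this
    # restores the old variable-shape loss computation for this model instance.
    model = TFAutoModelForTokenClassification.from_pretrained(
        "bert-base-cased", tf_legacy_loss=True  # checkpoint name is just an example
    )
    # The default (tf_legacy_loss=False) uses the new fixed-shape, XLA-friendly losses.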
92 changes: 73 additions & 19 deletions src/transformers/modeling_tf_utils.py
@@ -195,11 +195,22 @@ def hf_compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100 affect the loss
            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only labels that are not equal to -100 affect the loss
        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
        # Avoid division by zero later
        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
        return reduced_masked_loss
Collaborator:
I believe this is not equivalent to the previous computation.

  • previous: we compute the loss for each actual token.

  • this PR: along each batch dimension (i.e. for each sequence), the loss is averaged (over the actual tokens in that sequence).

    • returning masked_loss should be fine (as we get 0 for inactive tokens)
    • (but I don't know why we don't return a scalar loss obtained by averaging over all active tokens - this is what is done in GPT2LMHeadModel, for example)

Contributor:
I was checking the docs now, and PT only outputs one number for the batch -- so it makes sense to include the sum here 👍

(we would have to update the TF docstring)

Collaborator:

(BTW, PT outputs the average instead of the sum, as I see in GPT2LMHeadModel - CrossEntropyLoss has reduction='mean' by default.)

Member Author:

It will be impossible to keep the old behaviour with XLA, because the number of 'active' tokens will change in each batch.

We could return a scalar number, but in Keras it's nice to return a vector of per-sample losses, because this means the user can use the sample_weight argument to fit() if they want to. I think that's fairly uncommon though, so if we want to stick with a scalar, that's fine!
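To make the two candidate reductions in this thread concrete, a toy sketch (made-up numbers, not from the PR):

    import tensorflow as tf

    # Per-token losses for a batch of 2 sequences; -100 marks masked label positions.
    per_token_loss = tf.constant([[1.0, 3.0, 0.0], [4.0, 0.0, 0.0]])
    labels = tf.constant([[5, 7, -100], [9, -100, -100]])
    mask = tf.cast(labels != -100, tf.float32)

    # This PR: one value per sequence, averaged over that sequence's active tokens,
    # which keeps a per-sample vector that Keras can weight individually.
    per_sample = tf.reduce_sum(per_token_loss * mask, axis=1) / tf.maximum(
        1.0, tf.reduce_sum(mask, axis=1)
    )  # -> [2.0, 4.0]

    # PyTorch-style scalar: a single average over all active tokens in the batch
    # (CrossEntropyLoss defaults to reduction="mean" with ignore_index=-100).
    scalar = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)  # -> 8/3 ≈ 2.67

    # The two differ whenever sequences contain unequal numbers of active tokens.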



class TFQuestionAnsweringLoss:
@@ -232,17 +243,34 @@ def hf_compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")

        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
                active_loss = tf.reshape(labels, (-1,)) != -1
            else:
                active_loss = tf.reshape(labels, (-1,)) != -100
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only labels that are not equal to -100 or -1
        # are taken into account as loss
        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
        # Avoid possible division by zero later
        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
        return reduced_masked_loss
Collaborator:
Same question as above about the averaging along dim 1, as in TFCausalLanguageModelingLoss.
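As background for the `tf.executing_eagerly()` guard in the hunk above, a small sketch (illustrative only, not the library's code) of why a Python conditional on a tensor value cannot survive tracing/XLA compilation:

    import tensorflow as tf

    @tf.function(autograph=False)  # roughly what a compiled/XLA trace sees
    def warn_on_deprecated_labels(labels):
        # During tracing, reduce_any(...) is a symbolic tensor with no concrete value,
        # so using it as a Python bool raises OperatorNotAllowedInGraphError.
        if tf.math.reduce_any(labels == -1):
            tf.print("deprecated label value")
        return labels

    # warn_on_deprecated_labels(tf.constant([1, -1]))  # would raise at trace time
    # Guarding the check with tf.executing_eagerly() simply skips it when tracing.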



class TFSequenceClassificationLoss:
@@ -251,7 +279,7 @@ class TFSequenceClassificationLoss:
"""

def hf_compute_loss(self, labels, logits):
if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
if logits.shape.rank == 1 or logits.shape[1] == 1:
Collaborator:
Is this change really necessary?

  • I have experienced a few times that .shape does not work in graph mode, and when that happens, shape_list makes those tests pass.

  • I heard @gante mention some issues with shape_list + XLA, but I didn't check (and have completely forgotten what was wrong).

Contributor:
I also don't remember the exact issue, other than it often causes XLA compilation to fail :)

Member Author:
Leaving tf.shape() here did cause XLA compilation to fail. I could use shape_list, but I think .shape and .shape.rank are fine!

Collaborator:
I didn't see tf.shape used in the previous version.

Regarding .shape vs. shape_list, I am OK as long as things work. I am just confused that most of the time (in other places) I see .shape fail while shape_list works (with graph mode / symbolic tensors, etc.).

Member Author:
Ah, sorry, I should explain! shape_list uses a combination of .shape and tf.shape() to build the list. I think using tf.shape() here confuses XLA, because it looks like a conditional that depends on the specific data you input, and those are forbidden.

I'm not 100% sure of the exact rules it uses, but all I can tell you is that it failed before and it works like this!
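To illustrate the static vs. dynamic shape distinction in this thread (a simplified sketch, not the actual shape_list implementation):

    import tensorflow as tf

    def describe_shapes(logits):
        static_rank = logits.shape.rank            # a Python int, known at trace time
        static_dim = logits.shape[1] if static_rank > 1 else None  # may be None if unknown
        dynamic_shape = tf.shape(logits)           # a tensor, only known when the graph runs
        return static_rank, static_dim, dynamic_shape

    # Branching on the static values (as the new `logits.shape.rank == 1` check does) is
    # resolved while tracing, so the compiled graph contains no conditional at all.
    # Branching on tf.shape(logits) instead would put a runtime-dependent conditional
    # into the graph, which is the kind of thing that was tripping up XLA here.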

loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
else:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
@@ -298,13 +326,25 @@ def hf_compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)

            return loss_fn(next_sentence_label, next_sentence_reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
        # Just zero out samples where label is -100, no reduction
        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

        return masked_ns_loss
Contributor:
Should this one be reduced as well?

Member Author:
ns_loss is calculated per sample rather than per position, so masked_ns_loss is a vector of shape (num_samples,). We could reduce that to a scalar if we want, though!
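A quick shape check (toy values, illustrative only) of the point above: the per-sample NSP cross-entropy already has shape (num_samples,).

    import tensorflow as tf

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    logits = tf.constant([[2.0, -1.0], [0.5, 0.5], [-3.0, 3.0]])  # (num_samples, 2)
    labels = tf.constant([0, 1, -100])
    per_sample = loss_fn(tf.nn.relu(labels), logits)              # shape (3,), one value per sample
    masked = per_sample * tf.cast(labels != -100, per_sample.dtype)
    # `masked` is already a vector of per-sample losses, so no further reduction is needed.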



def booleans_processing(config, **kwargs):
@@ -1327,6 +1367,13 @@ def train_step(self, data):
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
if isinstance(x, dict):
x = x.copy()
if isinstance(y, dict):
y = y.copy()
Collaborator:
nice~


# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
@@ -1424,6 +1471,13 @@ def test_step(self, data):
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
if isinstance(x, dict):
x = x.copy()
if isinstance(y, dict):
y = y.copy()

# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
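To illustrate the aliasing problem that the new x.copy() / y.copy() calls in train_step and test_step avoid (hypothetical key names, plain Python for clarity):

    # Without a copy, `x` is just another name for the caller's dict, so popping a key
    # during label pre-processing silently mutates the user's data.
    data = {"input_ids": [[101, 2023, 102]], "labels": [[1, 2, 3]]}

    x = data
    x.pop("labels")
    assert "labels" not in data   # the caller's dict was modified

    data = {"input_ids": [[101, 2023, 102]], "labels": [[1, 2, 3]]}
    x = data.copy()               # shallow copy: new dict, same underlying values
    x.pop("labels")
    assert "labels" in data       # the caller's dict is untouched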
67 changes: 45 additions & 22 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -86,29 +86,52 @@ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
            masked_lm_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
                mask=masked_lm_active_loss,
            )
            masked_lm_labels = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
            )
            sentence_order_active_loss = tf.not_equal(
                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
            )
            sentence_order_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
            )
            sentence_order_label = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
            )
            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)

            return masked_lm_loss + sentence_order_loss

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        # Avoid division by zero later
        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

        sop_logits = tf.reshape(logits[1], (-1, 2))
        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)

        # No reduction because this already has shape (num_samples,)
        masked_sop_loss = unmasked_sop_loss * sop_loss_mask

        return reduced_masked_lm_loss + masked_sop_loss


class TFAlbertEmbeddings(tf.keras.layers.Layer):
14 changes: 9 additions & 5 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -124,18 +124,22 @@ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        # Avoid potential division by zero later
        lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
        ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
        # Just zero out samples where label is -100, no reduction
        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

        return reduced_masked_lm_loss + masked_ns_loss

25 changes: 17 additions & 8 deletions src/transformers/models/led/modeling_tf_led.py
@@ -2505,11 +2505,20 @@ def _reorder_cache(past, beam_idx):
    def hf_compute_loss(self, labels, logits):
        """CrossEntropyLoss that ignores pad tokens"""
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        if self.config.tf_legacy_loss:
            melted_labels = tf.reshape(labels, (-1,))
            active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(melted_labels, active_loss)
            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only non-padding labels affect the loss
        loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
        return reduced_masked_loss
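A tiny sketch (made-up token ids) of the masking convention used here: LED keys the mask off config.pad_token_id rather than -100.

    import tensorflow as tf

    pad_token_id = 1  # assumption: example value only, the real value comes from the model config
    labels = tf.constant([[42, 17, 1, 1], [8, 1, 1, 1]])
    loss_mask = tf.cast(labels != pad_token_id, tf.float32)
    active_per_sequence = tf.reduce_sum(loss_mask, axis=1)  # -> [2., 1.]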