XLA train step fixes #17973
Changes from all commits
```diff
@@ -195,11 +195,22 @@ def hf_compute_loss(self, labels, logits):
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100 affect the loss
-        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-        return loss_fn(labels, reduced_logits)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100 affect the loss
+            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 affect the loss
+        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
+        # Avoid division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFQuestionAnsweringLoss:
```
Collaborator
I believe this is not equivalent to the previous computation.

Contributor
I was checking the docs now, and PT only outputs one number for the batch -- so it makes sense to include the sum here 👍 (we would have to update the TF docstring)

Collaborator
(BTW, PT outputs the average instead of the sum, as I see in …

Member
Author
It will be impossible to keep the old behaviour with XLA, because the number of 'active' tokens will change in each batch. We could return a scalar number, but in Keras it's nice to return a vector of per-sample losses, because this means the user can use the …
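To make the trade-off concrete, here is a small standalone sketch (a toy example of my own, not code from this PR; the function names and dummy shapes are invented). The legacy `tf.boolean_mask` path returns a tensor whose length depends on how many labels are unmasked, so its shape changes from batch to batch and cannot be compiled with `jit_compile=True`, while the mask-multiply path always returns one loss value per sample.

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

def legacy_masked_loss(labels, logits):
    # Gathers only the unmasked positions: the output length depends on the data,
    # which gives a dynamic shape that XLA cannot compile.
    active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
    reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[-1])), active)
    active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)
    return loss_fn(active_labels, reduced_logits)  # shape (num_unmasked_tokens,)

def xla_friendly_masked_loss(labels, logits):
    # Clip -100 to 0 so the loss never sees an invalid index, then zero out those positions.
    unmasked = loss_fn(tf.nn.relu(labels), logits)         # shape (batch, seq_len)
    mask = tf.cast(labels != -100, unmasked.dtype)
    denom = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)   # guard against fully masked rows
    return tf.reduce_sum(unmasked * mask, axis=1) / denom  # static shape (batch,)

labels = tf.constant([[1, 2, -100], [0, -100, -100]])
logits = tf.random.normal((2, 3, 5))
print(legacy_masked_loss(labels, logits).shape)        # (3,) here, but varies with the labels
print(xla_friendly_masked_loss(labels, logits).shape)  # always (2,), one loss per sample
```

Returning a fixed-length per-sample vector is also what lets Keras features keep working, which seems to be the point being made in the truncated comment above.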
```diff
@@ -232,17 +243,34 @@ def hf_compute_loss(self, labels, logits):
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        if tf.math.reduce_any(labels == -1):
-            tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
-
-        return loss_fn(labels, reduced_logits)
+        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+                active_loss = tf.reshape(labels, (-1,)) != -1
+            else:
+                active_loss = tf.reshape(labels, (-1,)) != -100
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 or -1
+        # are taken into account as loss
+        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
+        # Avoid possible division by zero later
+        loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
+        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
+        return reduced_masked_loss


 class TFSequenceClassificationLoss:
```
Collaborator
Same question as above about the averaging along dim 1 in …
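On the division-by-zero guard specifically, a tiny toy example (the values here are invented, not from the PR) shows what `tf.math.maximum(..., 1)` protects against when every label in a row is masked:

```python
import tensorflow as tf

per_token_loss = tf.constant([[0.7, 0.2, 0.1],
                              [0.5, 0.9, 0.4]])
loss_mask = tf.constant([[1.0, 1.0, 0.0],
                         [0.0, 0.0, 0.0]])  # second row: every label was -100 / -1

# Without the guard, the fully masked row divides 0 by 0 and produces NaN.
naive = tf.reduce_sum(per_token_loss * loss_mask, axis=1) / tf.reduce_sum(loss_mask, axis=1)
# With the guard, the denominator is clamped to 1 and that row's loss is simply 0.
guarded = tf.reduce_sum(per_token_loss * loss_mask, axis=1) / tf.maximum(
    tf.reduce_sum(loss_mask, axis=1), 1.0
)
print(naive.numpy())    # [0.45  nan]
print(guarded.numpy())  # [0.45 0.  ]
```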
```diff
@@ -251,7 +279,7 @@ class TFSequenceClassificationLoss:
     """

     def hf_compute_loss(self, labels, logits):
-        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
             loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
         else:
             loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
```
Collaborator
Is this change really necessary?

Contributor
I also don't remember the exact issue, other than it often causes XLA compilation to fail :)

Member
Author
Leaving …

Collaborator
I didn't see … Regarding …

Member
Author
Ah, sorry, I should explain! I'm not 100% sure of the exact rules it uses, but all I can tell you is that it failed before and it works like this!
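For readers following along, the distinction under discussion is static versus dynamic shapes under XLA. Below is a toy contrast (an assumed setup, not the exact failure seen in this PR): `logits.shape` is resolved while the function is traced, so branching on it is plain Python, whereas helpers built on `tf.shape` produce runtime tensors, which the author reports tripping up compilation.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def pick_loss_branch(logits):
    # `logits.shape` is known at trace time, so this `if` is ordinary Python and
    # never becomes a data-dependent conditional inside the compiled graph.
    if logits.shape.rank == 1 or logits.shape[1] == 1:
        return tf.reduce_mean(tf.square(logits))      # regression-style branch
    return tf.reduce_mean(tf.nn.softmax(logits))      # classification-style branch

print(pick_loss_branch(tf.random.normal((4, 1))).numpy())  # traced with the first branch
print(pick_loss_branch(tf.random.normal((4, 3))).numpy())  # retraced with the second branch
```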
```diff
@@ -298,13 +326,25 @@ def hf_compute_loss(self, labels, logits):
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
-        # make sure only labels that are not equal to -100
-        # are taken into account as loss
-        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
-        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
-        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
-
-        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+            return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+        # make sure only labels that are not equal to -100
+        # are taken into account as loss
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
+        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask
+
+        return masked_ns_loss


 def booleans_processing(config, **kwargs):
```
Contributor
Should this one be reduced as well?

Member
Author
…
```diff
@@ -1327,6 +1367,13 @@ def train_step(self, data):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()

         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
```
Collaborator
nice~
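As a plain-Python illustration of why the shallow copy matters (a hypothetical pre-processing step, not the real `train_step` code): if the step popped label keys out of the user's dict in place, the caller's batch would be silently altered, while `dict.copy()` is cheap and keeps the underlying tensors shared.

```python
def step_without_copy(x):
    return x.pop("labels")   # mutates the caller's dict

def step_with_copy(x):
    x = x.copy()             # shallow copy: the dict is new, the values are shared
    return x.pop("labels")   # only the local copy loses the key

batch = {"input_ids": [[101, 2023, 102]], "labels": [[1, 0, 1]]}

step_with_copy(batch)
print("labels" in batch)     # True  -- the user's data is untouched
step_without_copy(batch)
print("labels" in batch)     # False -- the user's batch has been wrecked
```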
```diff
@@ -1424,6 +1471,13 @@ def test_step(self, data):
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()

         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
```
Are we sure we want to set the default to False? This is breaking, no? Also, it's a somewhat hard-to-discover silent error in this case, no?

It's only very slightly breaking - anyone using Keras or a custom model will not notice any change. The existing losses return very strange shapes, like vectors of shape (num_unmasked_tokens,) that vary in length each iteration, with no mapping from there back to the original tokens. I doubt anyone is using them directly, without computing tf.reduce_mean() on them.
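For anyone who does depend on the old shapes, the diff gates the new behaviour on `self.config.tf_legacy_loss`, so the old computation should be recoverable by setting that flag on the model config. A minimal sketch, assuming the flag is an ordinary config attribute and using a placeholder checkpoint name:

```python
from transformers import AutoConfig, TFAutoModelForTokenClassification

# "bert-base-cased" is just a placeholder checkpoint for illustration.
config = AutoConfig.from_pretrained("bert-base-cased")
config.tf_legacy_loss = True  # opt back into the pre-PR loss computation
model = TFAutoModelForTokenClassification.from_pretrained("bert-base-cased", config=config)

# With the new default (tf_legacy_loss=False), losses come back as one value per sample;
# callers that want a scalar can still take the mean, e.g. tf.reduce_mean(outputs.loss).
```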