Fix BLOOM's softmax for half precisions #18185
Conversation
- avoid having both `-inf` and `dtype.min` in causal mask due to addition
- clip values between dtype max and min to avoid infs (not liked by softmax)
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
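For reference, here is a minimal sketch of the clipping idea from these commits, assuming a generic `scores` tensor and a causal mask built with `-inf` (illustrative names only, not the actual Bloom modeling code):

```python
import torch

dtype = torch.float16
scores = torch.randn(1, 4, 4).to(dtype)  # toy attention scores
# Causal mask using -inf for future positions (illustrative, not the library's mask code).
causal_mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1).to(dtype)

masked_scores = scores + causal_mask
# Clip to the finite range of the dtype so the softmax never sees -inf or +inf.
masked_scores = torch.clip(masked_scores, min=torch.finfo(dtype).min, max=torch.finfo(dtype).max)
probs = torch.softmax(masked_scores.float(), dim=-1).to(dtype)
```

Whether the softmax runs in float32 and is cast back, as above, is incidental; the point is only that its input stays finite.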
The documentation is not available anymore as the PR was closed or merged.
- it's okay to use addition since we're using `-inf` again
The two situations you described indeed exist. However, I think there is no real necessity to deal with them. As long as there is at least one position to attend to, it doesn't matter if we have mixed `-inf` and `torch.finfo(dtype).min` values. And for a sequence without any position to attend to, there is nothing we can do. If we want to be really rigorous, we should multiply the softmaxed scores by zeros for the unattended places.
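Purely as an illustration of that last suggestion (the tensor names below are made up, this is not code from the PR), zeroing the softmaxed scores for unattended positions could look like this:

```python
import torch

# attention_mask: 1 for positions that may be attended to, 0 for masked ones.
attention_mask = torch.tensor([[1, 1, 0, 0]])  # [bsz, seq_len]
scores = torch.randn(1, 1, 4, 4)               # [bsz, heads, tgt_len, src_len]

additive_mask = (1 - attention_mask)[:, None, None, :] * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive_mask, dim=-1)

# Multiply the softmaxed scores by zeros for the unattended places, so a row
# with nothing to attend to contributes nothing instead of a uniform distribution.
probs = probs * attention_mask[:, None, None, :]
```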
@ydshieh Are we sure
@NouamaneTazi I don't think there is such a guarantee, and what you mentioned is possible. However, it would be great if you could provide some examples where this PR gives better results or solves an issue. Thank you!
So stupid question: instead of running
```python
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
    expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
```
Stupid question, but when you sum `torch.finfo(dtype).min` with `torch.finfo(dtype).min`, isn't it no longer the masked value?
`torch.finfo(dtype).min + torch.finfo(dtype).min = -inf` (the sum overflows the dtype's finite range).
We would get something like this:

```python
>>> print(attention_mask)
tensor([[[[     0., -65504., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0.,      0.]]],

        [[[-65504.,    -inf,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504.,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504., -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0.,      0.]]]],
       device='cuda:0', dtype=torch.float16)
```
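A toy way to reproduce this kind of mixed mask (illustrative shapes and values only, not the output of the library's `_expand_mask`):

```python
import torch

dtype = torch.float16
min_value = torch.finfo(dtype).min  # -65504.0 for float16

# 0 marks attended positions, min_value marks blocked ones (left padding on 3 positions).
causal = torch.triu(torch.full((4, 4), min_value), diagonal=1).to(dtype)
padding = torch.tensor([min_value, min_value, min_value, 0.0], dtype=dtype).expand(4, 4)

print(causal + padding)
# Positions blocked by both masks overflow to -inf, positions blocked by only
# one stay at -65504, so the sum mixes -inf and torch.finfo(dtype).min.
```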
I'm not sure what
I think @thomasw21 is talking about the place where an attention score (which, as you say, could be positive) has the mask added to it.
Should be fixed in this PR: #18344
This PR aims at fixing the following issues:
- [`torch.finfo(dtype).min` in attention mask] In this line, if we use the minimum dtype values, then after performing the addition we get mixed `-inf` and `torch.finfo(dtype).min` values in the attention mask
- We therefore use `-inf` in the attention mask instead, and only after the addition do we replace the inf values by the respective max/min dtype values
- We use `torch.clip` instead of `torch.max` to ensure we avoid both `-inf` and `+inf` for softmax

All tests (including slow ones) are passing. ✅
Related to: #17437
Co-authored by: @younesbelkada
cc @ydshieh @stas00
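Putting the description together, a rough, self-contained sketch of the strategy might look like the following (the helper name `build_attention_bias` is hypothetical, and this is not the actual Bloom modeling code):

```python
import torch

def build_attention_bias(causal_mask: torch.Tensor, padding_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical sketch of the approach described above: both masks use -inf
    for blocked positions, so their sum never mixes -inf with the dtype minimum;
    the infs are then replaced by the finite dtype minimum and the result is
    clipped so the softmax never receives -inf or +inf."""
    combined = causal_mask + padding_mask
    min_value, max_value = torch.finfo(dtype).min, torch.finfo(dtype).max
    combined = combined.masked_fill(torch.isinf(combined), min_value)
    return torch.clip(combined, min=min_value, max=max_value).to(dtype)


seq_len = 4
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
padding = torch.tensor([float("-inf"), 0.0, 0.0, 0.0]).expand(seq_len, seq_len)
bias = build_attention_bias(causal, padding, torch.float16)  # finite everywhere, no mixed values
```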