
Conversation

@NouamaneTazi (Member) commented on Jul 18, 2022

This PR aims to fix the following issues:

  • In this line, if we use the minimum dtype value in the attention mask to mask some positions, a masked value can "come back to life" once a positive score is added to it. This PR proposes to use `-inf` in the attention mask instead, and only after the addition do we replace the infinite values with the respective max/min dtype values:
        input_dtype = attention_scores.dtype
        attn_weights = (attention_scores * self.layer_number) + attention_mask # torch.finfo(torch.float16).min + 1 is no longer torch.finfo(torch.float16).min (no longer hidden)
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
        attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
  • Use `torch.clip` instead of `torch.max` to ensure we avoid both `-inf` and `+inf` going into the softmax (a minimal sketch of the combined approach follows this list)
  • [Only relevant if we use `torch.finfo(dtype).min` in the attention mask] In this line, if we use the minimum dtype value, then after performing the addition we get a mix of `-inf` and `torch.finfo(dtype).min` in the attention mask:
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask  # this gives `-inf` when we subtract a number from `torch.finfo(dtype).min`
            )
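Below is a minimal sketch of the scheme described above (hypothetical function and tensor names, not the exact model code): build the additive mask with `-inf`, add it to the raw scores, and only then clip back into the dtype's finite range before the softmax.

```python
import torch

def masked_softmax(attention_scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask holds 0.0 at attended positions and -inf at masked ones,
    # so adding a (possibly positive) score can never "revive" a masked position.
    dtype = attention_scores.dtype
    attn_weights = attention_scores + attention_mask
    # replace any resulting +/-inf with the finite dtype extremes (torch.clip is an alias of torch.clamp)
    attn_weights = torch.clip(attn_weights, min=torch.finfo(dtype).min, max=torch.finfo(dtype).max)
    return torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype)
```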

All tests (including slow ones) are passing. ✅
Related to: #17437
Co-authored-by: @younesbelkada
cc @ydshieh @stas00

NouamaneTazi and others added 3 commits July 18, 2022 15:31
- avoid having both `-inf` and `dtype.min` in causal mask due to addition
- clip values between dtype max and min to avoid infs (not liked by softmax)

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@HuggingFaceDocBuilderDev commented on Jul 18, 2022

The documentation is not available anymore as the PR was closed or merged.

- it's okay to use addition since we're using `-inf` again
@ydshieh (Collaborator) commented on Jul 18, 2022

The two situations you described do indeed exist. However, I think there is no real need to deal with them.

As long as there is at least one position to attend to, it doesn't matter if we have a mix of `-inf` and `torch.finfo(...).min`, or if a positive value gets added to `torch.finfo(...).min`. As long as the score(s) for the attended position(s) are within a reasonable range, they will dominate the scores of the unattended positions. (This should hold during inference of a trained model; otherwise the model is broken.)

And for a sequence without any position to attend to, there is nothing we can do. If we want to be really rigorous, we should multiply the softmaxed scores by zeros at the unattended positions.
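A small illustration of this point (toy values, not from the model): as long as one position carries a reasonable score, the softmax output is the same whether the masked positions hold `torch.finfo(...).min` or `-inf`; the only problematic case is a row with no attended position at all.

```python
import torch

row_min = torch.tensor([2.0, torch.finfo(torch.float32).min, torch.finfo(torch.float32).min])
row_inf = torch.tensor([2.0, float("-inf"), torch.finfo(torch.float32).min])
print(torch.softmax(row_min, dim=-1))  # tensor([1., 0., 0.])
print(torch.softmax(row_inf, dim=-1))  # tensor([1., 0., 0.])

# A fully masked row, however, comes out uniform rather than zero, which is why being
# fully rigorous would mean multiplying the softmaxed scores by zeros afterwards.
fully_masked = torch.full((3,), torch.finfo(torch.float32).min)
print(torch.softmax(fully_masked, dim=-1))  # tensor([0.3333, 0.3333, 0.3333])
```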

@NouamaneTazi (Member, Author) commented on Jul 18, 2022

@ydshieh Are we sure `attention_scores` can never have very large values? The worst-case scenario would be for `attention_scores` to have its biggest value at a masked (hidden) token.
Also, comparing the outputs before and after this PR, it does seem that we get better generations (less repetition), but this needs more testing to be confirmed.
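A toy illustration of the worst case mentioned above (made-up numbers): once the mask is added, a masked position with a large raw score is no longer at the dtype minimum, so the masking is weaker than intended.

```python
import torch

score = torch.tensor(32000.0, dtype=torch.float16)                        # large raw score at a masked position
mask = torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16)  # -65504.0
print(score + mask)  # tensor(-33504., dtype=torch.float16) -- far above torch.finfo(torch.float16).min
```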

@ydshieh (Collaborator) commented on Jul 18, 2022

@NouamaneTazi I don't think there is such a guarantee, and what you mentioned is possible. However, it would be great if you could provide some examples where this PR helps to get better results or solves some issues. Thank you!

@thomasw21 (Contributor) commented:

So, stupid question: instead of applying the `+` operator, can we not apply an elementwise `min` with an attention mask that holds `torch.finfo(dtype).max` at unmasked positions and `torch.finfo(dtype).min` at masked positions, and be done with it? Or use `torch.masked_fill(attention_mask, torch.finfo(dtype).min)`?

# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
    expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
Contributor:

Stupid question, but when you sum `torch.finfo(dtype).min` with `torch.finfo(dtype).min`, it's not the masked value anymore?

Member Author (@NouamaneTazi):

`torch.finfo(dtype).min + torch.finfo(dtype).min = -inf` (the sum overflows the dtype's finite range).
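For the record, a quick standalone check of that overflow in half precision:

```python
import torch

m = torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16)  # -65504.0
print(m + m)  # tensor(-inf, dtype=torch.float16): the sum overflows the fp16 range
```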

Member Author (@NouamaneTazi):

We would get something like this:

>>> print(attention_mask)
tensor([[[[     0., -65504., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0.,      0.]]],


        [[[-65504.,    -inf,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504.,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504., -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0.,      0.]]]],
       device='cuda:0', dtype=torch.float16)

@NouamaneTazi (Member, Author) commented:

> So, stupid question: instead of applying the `+` operator, can we not apply an elementwise `min` with an attention mask that holds `torch.finfo(dtype).max` at unmasked positions and `torch.finfo(dtype).min` at masked positions, and be done with it? Or use `torch.masked_fill(attention_mask, torch.finfo(dtype).min)`?

I'm not sure which `+` operator you are referring to. Is it after the softmax operation, or when creating the attention mask?

@ydshieh (Collaborator) commented on Jul 19, 2022

> I'm not sure which `+` operator you are referring to. Is it after the softmax operation, or when creating the attention mask?

I think @thomasw21 is talking about the place where the mask is added to an attention score (which, as you say, could be positive).
Regarding @thomasw21's question: it is also a valid approach (it's like a clamp applied in a different order, with fewer ops). The current approach (a simple `+`) probably comes from the first model(s), like BERT/GPT-2.
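A rough sketch of that alternative (hypothetical function and tensor names): either take an elementwise `min` against a mask holding the dtype extremes, or fill the masked positions directly, so no addition is involved at all.

```python
import torch

def mask_scores_via_min(scores: torch.Tensor, minmax_mask: torch.Tensor) -> torch.Tensor:
    # minmax_mask: torch.finfo(dtype).max at attended positions, torch.finfo(dtype).min at masked ones
    return torch.min(scores, minmax_mask)

def mask_scores_via_fill(scores: torch.Tensor, is_masked: torch.Tensor) -> torch.Tensor:
    # is_masked: boolean tensor, True at positions that must not be attended to
    return scores.masked_fill(is_masked, torch.finfo(scores.dtype).min)
```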

@NouamaneTazi (Member, Author) commented:

Should be fixed in this PR: #18344
