
Conversation

@younesbelkada
Contributor

younesbelkada commented Jul 14, 2022

What does this PR do?

This PR tries to address a strange behaviour observed when running inference on the BLOOM-176B model using DeepSpeed!
My intuitions are:

  • In the previous code we used -10000 as the attention-mask fill value, whereas we should use the fp32 minimum, as is done in the original CUDA kernel of FusedScaledSoftmax (see the sketch after this list). This may lead to inconsistent results between the old and new versions, but the new version should be considered the correct one
  • @RezaYazdaniAminabadi discovered that the attention scores should not be multiplied by the attention mask after the softmax, which makes sense and could fix the issue
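To make the first point concrete, here is a minimal sketch (not the actual modeling_bloom code; `masked_softmax` is just an illustrative helper) of filling masked positions with the dtype minimum instead of a hard-coded -10000:

```python
import torch

# Minimal sketch, not the actual modeling_bloom code: mask the attention scores
# with the dtype's minimum value (what the FusedScaledSoftmax CUDA kernel uses)
# instead of a hard-coded -10000.
def masked_softmax(scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask holds 1 for positions to attend to and 0 for masked positions
    fill_value = torch.finfo(scores.dtype).min  # ~-3.4e38 for fp32, -65504 for fp16
    scores = scores.masked_fill(attention_mask == 0, fill_value)
    return torch.nn.functional.softmax(scores, dim=-1)
```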

cc @RezaYazdaniAminabadi @stas00 @thomasw21

- remove element-wise multiplication after softmax
younesbelkada referenced this pull request Jul 14, 2022
* fix tolerance for a bloom slow test

* enhance alibi padding

- get rid of for loops
- deals better with padded batched input
- avoid useless cpu/gpu communication when creating alibi

Co-authored-by: justheuristic <justheuristic@gmail.com>

* optimize attention mask

* fix scaled softmax limit values

* optimize building alibi tensor

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix attention_mask shape when it's None

* minor fixes

- fix docstring + arg names

* remove colons in docstring

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply suggestion

* remove unused arg

* refactor a bit

- use [:, None] for consistency

* refactor attention block

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>

* quick fixes

* first attempt

* refactor attention block and fix all tests except "test_simple_generation"

- added comments to better explain attention block

* remove debug lines and add TODO comment

* change `torch.bmm` to `torch.baddbmm`
- fixes `test_simple_generation` but breaks `test_batch_generation_padd`

* styling

* all tests are passing now
- use `bmm`
- add explanation for `allow_fp16_reduced_precision_reduction`

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* styling

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix support for accelerate

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove attn softmax in fp32

* refactor comments

* refactor a bit

- remove warning message
- remove print on test

* refer to pytorch t5

* change the slow tests

- do the tests in fp32
- remove some comments
- keep large comments

* update expected output for `test_simple_generation`
- we now test using fp32

* make style + change comments a bit

* fix dtype padd test

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
younesbelkada changed the title from Fix DeepSpeed inference issue to Fix BLOOM DeepSpeed inference issue Jul 14, 2022
@younesbelkada
Contributor Author

@RezaYazdaniAminabadi did you try to run inference with the element-wise multiplication after the softmax removed, as proposed in this PR?
When running inference on 8xA100 80GB I obtained the same generations with the old code and the new one with batch_size=1.

@RezaYazdaniAminabadi
Contributor

@RezaYazdaniAminabadi did you try to run inference with the element-wise multiplication after the softmax removed, as proposed in this PR? When running inference on 8xA100 80GB I obtained the same generations with the old code and the new one with batch_size=1.
Hi @younesbelkada,

I did try this previously on 16 A100-40GB GPUs and it was not giving the same results. I will try with this one and let you know. Anyhow, I think that multiplication is not needed since the scores are already masked.
Thanks

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada
Contributor Author

Thank you very much @RezaYazdaniAminabadi !!

@younesbelkada
Contributor Author

Finally, after doing some tests, it appears that we do need the multiplication with the attention mask, for the following reason:
in some cases we have an attention mask like the one below

0 0 0 0 0 
0 1 0 0 0
0 1 1 0 0 
0 1 1 1 0 
0 1 1 1 1

After replacing all zeros with torch.finfo(dtype).min, the softmax returns the following on the first row:
0.2 0.2 0.2 0.2 0.2, because all the values on that row are equal. To avoid these wrong values propagating into the later computation, I had to multiply the attention scores by the original mask.
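A tiny illustration of this (the scores are random; only the mask pattern matters):

```python
import torch

# First row of the mask above: the query is a padding token, so every position is masked.
mask = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0])
scores = torch.randn(5)
probs = torch.softmax(scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min), dim=-1)
print(probs)         # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
print(probs * mask)  # tensor([0., 0., 0., 0., 0.]) -- the extra multiplication zeroes the row
```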

cc @NouamaneTazi

@stas00
Contributor

stas00 commented Jul 18, 2022

@younesbelkada, ok, so we have the first row of 0.2 0.2 0.2 0.2 0.2 - let's follow it through to the end: where does that manifest as an issue?

Let's perhaps use a small concrete example and use it to document why things are done the way they are - otherwise everybody will keep on questioning why this is done this way.

@thomasw21
Contributor

If this is because of padding, we should not care about the padding row, i.e. the case where the padding token is the query. The wrong values don't matter when they sit on padding positions, no?

@younesbelkada
Contributor Author

younesbelkada commented Jul 19, 2022

My guess was that this would impact the computation of the context_layer tensor here in the case where we have padded inputs, as mentioned by @thomasw21.
So in the end you are right! It does impact the computation of this tensor, but I think it does not matter at all. In the end we get a token-to-token correspondence for the computed hidden states, i.e. the context layer has shape batch_size x seq_len x hidden_dim, and the hidden states corresponding to the padding tokens do not impact the prediction of the next token anyway. Do you think this explanation makes sense?
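Here is a rough sketch of that argument, with made-up shapes and a hand-built probs tensor just for illustration:

```python
import torch

batch_size, seq_len, hidden_dim = 1, 5, 8
value = torch.randn(batch_size, seq_len, hidden_dim)

# Hand-built attention probabilities: identity rows for real tokens, and the
# uniform 0.2 row at index 0 standing in for the fully masked padding query.
probs = torch.eye(seq_len).unsqueeze(0)
probs[:, 0, :] = 0.2

context_layer = torch.bmm(probs, value)  # (batch_size, seq_len, hidden_dim)
# Only context_layer[:, 0, :] is built from the "wrong" row, and position 0 is a
# padding token, so its hidden state never influences the next-token prediction.
```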

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this Aug 22, 2022