Bloom Optimize operations #17866
Conversation
- get rid of for loops
- deals better with padded batched input
- avoid useless cpu/gpu communication when creating alibi
Co-authored-by: justheuristic <justheuristic@gmail.com>
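As context for the alibi changes listed above, here is a minimal sketch (assumed shapes and helper name, not necessarily the exact code merged in this PR) of how an ALiBi bias tensor can be built fully vectorised, without Python for-loops, deriving token positions from the padded attention mask so that padding does not shift the positions of real tokens and no per-element CPU/GPU synchronisation is needed:

```python
import math
import torch

def build_alibi_tensor_sketch(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """Vectorised ALiBi bias of shape (batch_size, num_heads, seq_length). Illustrative sketch only."""
    # Head-specific slopes from the ALiBi paper: a geometric sequence starting at 2^(-8/num_heads),
    # with extra interpolated slopes when num_heads is not a power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Relative positions derived from the (possibly padded) attention mask: padded tokens get
    # position 0, real tokens count from 0 upwards, all computed on the mask's own device.
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    return (slopes.to(attention_mask.device)[..., None] * arange_tensor).to(dtype)
```

The bias is then added to the attention scores per head; the modeling code may reshape it further (e.g. to batch_size * num_heads, 1, seq_length) before use.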
The documentation is not available anymore as the PR was closed or merged.
I won't merge this now since I saw that it broke some slow tests; I'll investigate that!
With the two proposed changes, all tests are now passing @younesbelkada :)
Thanks a lot @NouamaneTazi!! Amazing job 🔥
Let's merge this, as it brings some improvements that make inference faster.
Before merging, let's fix the code quality tests...
Force-pushed from ee5a2d5 to fcfe5b7 (compare)
…to pr/younesbelkada/17866
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
sgugger
left a comment
Left a few comments on the style. Let's also wait for @patrickvonplaten's comments, as he dove deeper into the code than I did.
mask = mask.to(dtype)

if past_key_values_length > 0:
    mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
Just to clarify here: we put zeros because those tokens should be attended to, no? The masked-out tokens get the corresponding torch.finfo(dtype).min, no?
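For readers following along, here is a minimal, self-contained sketch of the additive-mask convention being discussed (shapes and variable names are illustrative, not the exact ones from the modeling file): positions that may be attended to carry 0, masked positions carry torch.finfo(dtype).min, and the mask is simply added to the attention scores before the softmax.

```python
import torch

# Illustrative sketch, not the PR's exact code: 0 = "attend", finfo.min = "masked out".
dtype = torch.float32
tgt_len, past_key_values_length = 2, 3

# Causal part for the newly generated tokens: strictly-upper-triangular entries (future
# positions) are set to the most negative representable value.
mask = torch.triu(
    torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype), diagonal=1
)

# Cached past tokens must always be visible, hence the block of zeros prepended on the left.
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)

# Adding the mask to the raw scores drives masked positions to ~0 probability after softmax.
scores = torch.randn(tgt_len, tgt_len + past_key_values_length)
probs = torch.softmax(scores + mask, dim=-1)
```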
All tests are passing now! Is it OK if we merge this, @stas00? (Since you are working on DS inference, just checking that this PR does not conflict with your work.)
I didn't have a chance to read this PR, but let me at least run a quick test with it. Update: it looks fine for the 350m model - I'm waiting for the 176B model to download and will test with it as well. If in a rush, please go ahead and merge, and if anything emerges we can fix it afterwards.
There were lots of changes since you got approval, so please wait for a re-review from @patrickvonplaten and me.
# Here is a summary of an ablation study of our observations
# EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
# 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
I don't understand why we need to have different cases for torch.bmm and torch.baddbmm -> let's just use one of them, no? Also, if the fp16 tests are flaky, let's run the tests in fp32 instead.
Totally agree! Modified the tests to run in fp32; just checking that they effectively pass.
patrickvonplaten
left a comment
Some clean-ups to do in the modeling code, and I'm not a fan of making use of allow_fp16_reduced_precision_reduction in the tests -> let's instead run the tests in fp32, maybe?
@slow
@require_torch_gpu
def test_batch_generation_padd(self):
    # With small models the test will fail because of the operator torch.baddm that will give inconsistent results
To be honest these statements are a bit confusing to me. I'd prefer to have stable tests in fp32 instead
Removed this comment but kept the one above; I think it is important to keep it to explain our findings in case we encounter a similar issue in the future.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
- remove warning message
- remove print on test
Thanks a lot @sgugger @patrickvonplaten
- do the tests in fp32
- remove some comments
- keep large comments
- we now test using fp32
All tests are passing now (tested on A100) 🎉
Now tests pass on both A100 and Titan RTX 🎉 (because we used fp32)
sgugger
left a comment
Thanks for all the iterations!
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
Unless I'm wrong, that's just a big true vector, no?
intermediate_mask = mask_cond[:, None] < (mask_cond + 1)[None, :]
Unless I'm wrong, that's a lower triangular mask, no?
True False False False
True True False False
True True True False
True True True True
Why not just run torch.triu(torch.full((target_length, target_length), torch.finfo(dtype).min, dtype=dtype), diagonal=1)? (You can even use the in-place version triu_ if you're worried about the memory footprint.)
I think that you are totally right! Thanks for the suggestion!
Let's maybe propose the enhancement in this PR: #18139
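For context, a small sketch of the equivalence discussed above, under the assumption that the intended pattern is the lower-triangular (causal) one printed in the comment; variable names are illustrative:

```python
import torch

dtype = torch.float32
target_length = 4

# Boolean lower-triangular pattern (True on and below the diagonal), built from aranges
# in the spirit of the construction quoted above.
mask_cond = torch.arange(target_length)
bool_lower_tri = mask_cond[:, None] >= mask_cond[None, :]

# Suggested alternative: build the additive mask directly with triu, leaving 0 on and
# below the diagonal and finfo(dtype).min strictly above it.
additive_mask = torch.triu(
    torch.full((target_length, target_length), torch.finfo(dtype).min, dtype=dtype),
    diagonal=1,
)

# The positions that are True in the boolean mask are exactly the positions left at 0
# in the additive mask, so the triu version can replace the arange-based one.
assert torch.equal(bool_lower_tri, additive_mask == 0)
```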
contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
    If True, make each chunk contiguous in memory.

if past_key_values_length > 0:
    mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
That's assuming that target_length == 1 no?
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
alibi = alibi.to(hidden_states.device)  # to make the model possible to run under accelerate
Should alibi have a default None then?
Nope, that's a nit as well! It should not default to None.
    alpha=(1.0 / self.norm_factor),
)
# [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
matmul_result = (1.0 / self.norm_factor) * torch.bmm(
Why was baddbmm dropped? Feels like exactly our use case, no?
baddbmm was initially replaced by bmm after several experiments trying to resolve the differences we noticed when using half precision (fp16 and bf16). Check our findings here.
The second reason is that we were also investigating whether the alpha and beta coefficients are necessary, as they were deemed unnecessary in the FLAX implementation.
Edit: I linked the wrong findings. I meant to link the comments for the test test_batch_generation_padd:
https://github.com/younesbelkada/transformers/blob/eb86c4369d465d6181a865e43fd198b28c0b7a02/tests/models/bloom/test_modeling_bloom.py#L450
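For readers not familiar with the two ops, the sketch below (placeholder shapes and names, not the exact variables from the modeling file) shows why baddbmm and bmm are interchangeable here: baddbmm fuses the alibi bias and the 1/norm_factor scaling into a single kernel, while the bmm variant applies them separately.

```python
import torch

# Placeholder shapes/names for illustration only.
batch_heads, q_length, k_length, head_dim = 4, 5, 5, 8
norm_factor = head_dim ** 0.5

query_layer = torch.randn(batch_heads, q_length, head_dim)
key_layer = torch.randn(batch_heads, head_dim, k_length)
alibi = torch.randn(batch_heads, q_length, k_length)

# Fused: beta * alibi + alpha * (query @ key) in a single call.
scores_baddbmm = torch.baddbmm(alibi, query_layer, key_layer, beta=1.0, alpha=1.0 / norm_factor)

# Unfused: plain batched matmul, with scaling and the alibi bias applied separately.
scores_bmm = alibi + (1.0 / norm_factor) * torch.bmm(query_layer, key_layer)

# Mathematically identical; in fp32 they agree to within rounding error.
assert torch.allclose(scores_baddbmm, scores_bmm, atol=1e-6)
```

In half precision the two variants can diverge measurably, which is what the ablation comments in the test file document.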
ydshieh
left a comment
I have 2 questions 🙏
def test_simple_generation(self):
    # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
    # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
    # We set allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
I am just wondering what this statement is for: We set allow_fp16_reduced_precision_reduction = True.
I believe we meant to say that we leave allow_fp16_reduced_precision_reduction = True, as set by default. But I agree with you that the sentence could use some rephrasing. Wdyt @younesbelkada?
Yeah, I agree that the sentence is a bit confusing; we should say something like:
Therefore we should set allow_fp16_reduced_precision_reduction = True
If we use should, it implies there are some reasons, and it would be better to include them.
However, if I understand correctly, what you mean here is more along the lines of:
As we leave the default value (True) for allow_fp16_reduced_precision_reduction, the tests failed when running in half precision with smaller models (350m)
(just a description of the situation; the fix we take is to run in fp32)
Exactly, yes, I totally agree with your proposal. Let's just describe the situation!
It's more of a wording thing; I will leave this for you two to decide and open a PR if necessary. My only goal is to avoid confusion for readers.
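For reference, the PyTorch flag under discussion is a global switch; a test that wanted to remove the architecture dependence without moving to fp32 could pin it explicitly, along the lines of the sketch below (illustrative only, not the approach that was merged):

```python
import torch

# Illustrative only: pin the reduced-precision-reduction behaviour for the duration of a
# test so the result does not depend on the GPU architecture's default. Requires a
# CUDA-enabled build of PyTorch.
previous = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
try:
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    # ... run the fp16 generation being tested ...
finally:
    # Restore the previous value so other tests are unaffected.
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = previous
```

The PR ultimately sidesteps the issue by running the affected slow tests in fp32 instead.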
# >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS (for use_cache=True and use_cache=False)
# >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
# >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
Looks like there is no mention of # >=760m + allow_fp16_reduced_precision_reduction = False + torch.baddm? (But I guess it passes 😄)
However, this seems to suggest that 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm is the only failing case, and yet we have a comment suggesting baddbmm was replaced by bmm?
I think that you are right. @NouamaneTazi, I feel that we should use baddbmm instead, no?
It was replaced by bmm because 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS is the only configuration that passed the test test_batch_generation_padd.
Btw, I just edited the comment you referenced. I linked the wrong test before. Here's the correct link: https://github.com/younesbelkada/transformers/blob/eb86c4369d465d6181a865e43fd198b28c0b7a02/tests/models/bloom/test_modeling_bloom.py#L450
And yes @younesbelkada, we could go back to using baddbmm to keep the same operation used during training, as mentioned by Thomas.
Ah okay, got it now, sorry! @ydshieh @NouamaneTazi, I just updated the comments to detail the reason why we were using bmm.
Thanks for pointing me to the correct place 👍
Will open a small PR and let you know @NouamaneTazi
Created a PR: #18175
Let's move the discussion there, @ydshieh @NouamaneTazi!
* fix tolerance for a bloom slow test
* enhance alibi padding
  - get rid of for loops
  - deals better with padded batched input
  - avoid useless cpu/gpu communication when creating alibi
  Co-authored-by: justheuristic <justheuristic@gmail.com>
* optimize attention mask
* fix scaled softmax limit values
* optimize building alibi tensor
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* fix attention_mask shape when it's None
* minor fixes - fix docstring + arg names
* remove colons in docstring
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* apply suggestion
* remove unsued arg
* refactor a bit - use [:, None] for consistency
* refactor attention block
  Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
* quick fixes
* first attempt
* refactor attention block and fix all tests except "test_simple_generation" - added comments to better explain attention block
* remove debug lines and add TODO comment
* change `torch.bmm` to `torch.baddbmm` - fixes `test_simple_generation` but breaks `test_batch_generation_padd`
* styling
* all tests are passing now - use `bmm` - add explanation for `allow_fp16_reduced_precision_reduction`
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* styling
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* fix support for accelerate
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* remove attn softmax in fp32
* refactor comments
* refactor a bit - remove warning message - remove print on test
* refer to pytorch t5
* change the slow tests - do the tests in fp32 - remove some comments - keep large comments
* update expected output for `test_simple_generation` - we now test using fp32
* make style + change comments a bit
* fix dtype padd test

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Moved the original PR: #17759 here to check if the tests pass