Conversation

@younesbelkada (Contributor)

Moved the original PR (#17759) here to check whether the tests pass.

younesbelkada and others added 2 commits June 9, 2022 18:04
- get rid of for loops
- deals better with padded batched input
- avoid useless cpu/gpu communication when creating alibi

Co-authored-by: justheuristic <justheuristic@gmail.com>
@HuggingFaceDocBuilderDev commented Jun 24, 2022

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada marked this pull request as ready for review June 24, 2022 14:35
@younesbelkada requested a review from sgugger June 24, 2022 14:35
@younesbelkada (Contributor, Author)

I won't merge this now since I saw that it broke some slow tests; I will investigate!

@NouamaneTazi (Member)

With the two proposed changes, all tests are now passing @younesbelkada :)

@younesbelkada (Contributor, Author)

Thanks a lot @NouamaneTazi !! Amazing job 🔥

@younesbelkada (Contributor, Author) commented Jul 6, 2022

Let's merge this, together with some improvements to make inference faster (see the sketch after the list below for the alibi point):

  • create attn mask only once
  • broadcast alibi only once instead of each time on the attention layer
  • Remove the contiguous calls and test the model
  • Refactor the reshaping (check how it is done in BLOOM Flax) in the attention layer
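For reference, here is a rough sketch (my own, simplified, not code from the PR) of the "broadcast alibi only once" idea: the ALiBi bias is built a single time per forward pass and reused by every attention layer instead of being recomputed inside each layer. The slope formula below assumes num_heads is a power of two; the real implementation also handles other head counts.

import torch

def build_alibi_bias(num_heads: int, seq_len: int, dtype=torch.float32):
    # Per-head slopes decay geometrically (simplified: assumes num_heads is a power of two).
    slopes = torch.tensor([2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)], dtype=dtype)
    # Relative positions 0..seq_len-1, broadcast against the slopes -> [num_heads, 1, seq_len]
    positions = torch.arange(seq_len, dtype=dtype)[None, None, :]
    return slopes[:, None, None] * positions

# Built once per forward pass, then reused by every attention layer.
alibi = build_alibi_bias(num_heads=16, seq_len=128)
print(alibi.shape)  # torch.Size([16, 1, 128])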

@sgugger (Collaborator) commented Jul 6, 2022

Before merging, let's fix the code quality tests...

@younesbelkada force-pushed the bloom-enhance-alibi branch from ee5a2d5 to fcfe5b7 on July 6, 2022 14:28
@younesbelkada changed the title from "Bloom enhance alibi creation + shifting" to "Bloom Optimize operations" Jul 6, 2022
@sgugger requested a review from patrickvonplaten July 7, 2022 12:22
@sgugger (Collaborator) left a comment

Left a few comments on the style. Let's also wait for @patrickvonplaten's comments, as he dove deeper into the code than I did.

mask = mask.to(dtype)

if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
Contributor

Just to clarify here: we put zeros because those tokens should be attended to, no? The masked-out tokens have the corresponding torch.finfo(dtype).min, no?
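For what it's worth, a minimal illustration (my own example, not code from the PR) of the additive-mask convention being discussed: positions that may be attended to get 0, masked positions get torch.finfo(dtype).min, so the softmax drives the masked slots to ~0.

import torch

dtype = torch.float32
scores = torch.zeros(1, 4)  # dummy attention scores
# First two positions attendable (0), last two masked (finfo.min).
mask = torch.tensor([[0.0, 0.0, torch.finfo(dtype).min, torch.finfo(dtype).min]])
print(torch.softmax(scores + mask, dim=-1))  # ~[0.5, 0.5, 0.0, 0.0]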

@younesbelkada (Contributor, Author)

All tests are passing now! Is it OK if we merge this, @stas00? (Since you are working on DS inference, just checking that this PR does not conflict with your work.)

@stas00 (Contributor) commented Jul 11, 2022

I didn't have a chance to read this PR, but let me at least run a quick test with it.

Update: it looks fine for the 350m model; I'm waiting for the 176B to download and will test with it as well.

If you're in a rush, please go ahead and merge; if anything emerges we can fix it afterwards.

@sgugger (Collaborator) commented Jul 11, 2022

There have been lots of changes since you got approval, so please wait for a re-review from @patrickvonplaten and me.


# Here is a summary of an ablation study of our observations
# EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
# 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
Contributor

I don't understand why we need to have different cases for torch.bmm and torch.baddbmm -> let's just use one of them, no? Also, if the fp16 tests are flaky, let's run the test in fp32 instead.

Contributor Author

Totally agree! Modified the tests to run in fp32; just checking that they actually pass.

@patrickvonplaten (Contributor) left a comment

Some clean-ups to do in the modeling code, and I'm not a fan of making use of allow_fp16_reduced_precision_reduction in the tests -> let's run the tests in fp32 instead, maybe?

@slow
@require_torch_gpu
def test_batch_generation_padd(self):
# With small models the test will fail because of the operator torch.baddm that will give inconsistent results
Contributor

To be honest, these statements are a bit confusing to me. I'd prefer to have stable tests in fp32 instead.

Contributor Author

Removed this comment but kept the one above; I think it is important to keep it to explain our findings if we encounter a similar issue in the future.

younesbelkada and others added 6 commits July 11, 2022 18:17
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
- remove warning message
- remove print on test
@younesbelkada (Contributor, Author)

Thanks a lot @sgugger @patrickvonplaten
I think that we will do the tests in fp32 instead; we just need to keep in mind that batched generation can be flaky for small models (<=350m), as @NouamaneTazi and I have identified. We will put a comment on the tests explaining what we have found, and I think we should be good to go!

younesbelkada and others added 3 commits July 11, 2022 18:42
- do the tests in fp32
- remove some comments
- keep large comments
@NouamaneTazi (Member) commented Jul 11, 2022

All tests are passing now (tested on A100) 🎉

@NouamaneTazi (Member) commented Jul 11, 2022

Now tests pass on both A100 and Titan RTX 🎉 (because we used fp32)
(Note that BloomModelTest::test_batch_generation_padd still fails on Titan RTX in fp16, both for this PR and for the main branch, because of the issue mentioned above.)

@sgugger (Collaborator) left a comment

Thanks for all the iterations!

@sgugger merged commit a462fc9 into huggingface:main Jul 11, 2022
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
@thomasw21 (Contributor) commented Jul 14, 2022

Unless I'm wrong, that's just a big true vector, no?

intermediate_mask = mask_cond[:,None] < (mask_cond + 1)[None, :]
Unless I'm wrong, that's a lower triangular mask, no?

True False False False
True True  False False
True True  True  False
True True  True  True

Why not just run torch.triu(torch.full((target_length, target_length), torch.finfo(dtype).min, dtype=dtype), diagonal=1)? (You can even use the in-place version triu_ if you're scared of the memory footprint.)

Contributor Author

I think that you are totally right! Thanks for the suggestion!
Let's maybe propose the enhancement in this PR: #18139
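As a sanity check, here is a quick sketch (my own, in float32, not code from the PR) showing that the suggested triu construction produces the same additive causal mask as the boolean-compare-plus-masked_fill approach:

import torch

dtype = torch.float32
n = 4
min_val = torch.finfo(dtype).min

# Original style: boolean lower-triangular condition, then zero out the allowed positions.
mask_a = torch.full((n, n), min_val)
cond = torch.arange(n)
mask_a.masked_fill_(cond[None, :] < (cond + 1)[:, None], 0.0)

# Suggested style: keep finfo.min strictly above the diagonal, zeros elsewhere, in one call.
mask_b = torch.triu(torch.full((n, n), min_val, dtype=dtype), diagonal=1)

print(torch.equal(mask_a, mask_b))  # True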

contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
If True, make each chunk contiguous in memory.
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
Contributor

That's assuming that target_length == 1, no?
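One way to see why the concatenation is shape-agnostic (my own illustration, not code from the PR): every query position may attend to all cached key positions, so the past columns get 0 for any target_length, while the causal part still masks future positions among the new tokens.

import torch

dtype = torch.float32
target_length, past_key_values_length = 2, 3
causal = torch.triu(torch.full((target_length, target_length), torch.finfo(dtype).min), diagonal=1)
mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), causal], dim=-1)
print(mask)  # shape [2, 5]: the three past columns are 0, the future new-token position is finfo.min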

# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
alibi = alibi.to(hidden_states.device) # to make the model possible to run under accelerate
Contributor

Should alibi have a default None then?

Contributor Author

Nope, that's a nit as well! It should not default to None.
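Going back to the reshape quoted above, here is a minimal sketch (shapes are my own example, not code from the PR) of how the fused QKV projection is exposed per head and then split into query, key and value chunks:

import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
hidden_size = num_heads * head_dim

# Output of the fused QKV projection: [batch_size, seq_length, 3 x hidden_size]
mixed_x_layer = torch.randn(batch_size, seq_length, 3 * hidden_size)

# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
mixed_x_layer = mixed_x_layer.view(batch_size, seq_length, num_heads, 3 * head_dim)

# Split the last dimension into query / key / value, each [batch_size, seq_length, num_heads, head_dim]
query, key, value = mixed_x_layer.split(head_dim, dim=-1)
print(query.shape, key.shape, value.shape)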

alpha=(1.0 / self.norm_factor),
)
# # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
matmul_result = (1.0 / self.norm_factor) * torch.bmm(
Contributor

Why was baddbmm dropped? Feels like exactly our use case, no?

Member

baddbmm was initially replaced by bmm after several experiments trying to resolve the differences we noticed when using half precision (fp16 and bf16). Check our findings here.
The second reason is that we were also investigating whether the alpha and beta coefficients were necessary, as they were deemed unnecessary in the Flax implementation.
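For context, a small illustration (my own, in fp32, not code from the PR) of the two formulations being discussed: torch.baddbmm fuses alibi + alpha * (Q @ K^T) in one call, while the bmm variant computes the matmul and adds alibi separately. In fp32 they agree to numerical precision; the half-precision discrepancies were attributed to reduced-precision accumulation.

import torch

bsz_times_heads, q_length, k_length, head_dim = 8, 5, 5, 16
norm_factor = head_dim ** 0.5

query = torch.randn(bsz_times_heads, q_length, head_dim)
key = torch.randn(bsz_times_heads, k_length, head_dim)
alibi = torch.randn(bsz_times_heads, q_length, k_length)

# Fused: alibi + (1 / norm_factor) * (query @ key^T)
fused = torch.baddbmm(alibi, query, key.transpose(1, 2), beta=1.0, alpha=1.0 / norm_factor)
# Unfused: same quantity with an explicit bmm and addition
unfused = alibi + (1.0 / norm_factor) * torch.bmm(query, key.transpose(1, 2))

print(torch.allclose(fused, unfused, atol=1e-5))  # True in fp32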

@ydshieh (Collaborator) left a comment

I have 2 questions 🙏

def test_simple_generation(self):
# This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
# do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
# We set allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
Collaborator

I am just wondering what this statement is for: We set allow_fp16_reduced_precision_reduction = True.

@NouamaneTazi (Member) commented Jul 18, 2022

I believe we meant to say that we leave allow_fp16_reduced_precision_reduction = True, as set by default. But I agree with you that the sentence could do with some rephrasing. Wdyt @younesbelkada?

Contributor Author

Yeah, agreed that the sentence is a bit confusing; we should say something like:
Therefore we should set allow_fp16_reduced_precision_reduction = True

Collaborator

If we use should, it sounds like there are specific reasons, and it would be better to include them.

However, if I understand correctly, what you mean here is more like:

As we leave the default value (True) for allow_fp16_reduced_precision_reduction, the tests failed when running in half precision with smaller models (350m)

(just a description of the situation, with the fix being to run in fp32)

Contributor Author

Exactly, yes, I totally agree with your proposal. Let's just describe the situation!

Collaborator

It's more of a wording thing; I will leave this for you two to decide and open a PR if necessary. My only goal is to avoid confusion for readers.
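For reference (my own note, not part of the PR): the setting under discussion is a global PyTorch flag for fp16 GEMM accumulation, which can be toggled as below. The tests here ultimately sidestep the issue by running in fp32 rather than changing the flag.

import torch

# Default is True on recent PyTorch versions: fp16 GEMMs may accumulate with reduced precision,
# which is where the small-model (<=350m) generation discrepancies were observed.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False  # force full-precision accumulation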

# >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS (for use_cache=True and use_cache=False)
# >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
# >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS

@ydshieh (Collaborator) commented Jul 18, 2022

Looks like there is no mention of # >=760m + allow_fp16_reduced_precision_reduction = False + torch.baddm? (But I guess it passes 😄)

However, this seems to suggest that 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm is the only failing case, and we have a comment

#17866 (comment)

suggesting that baddbmm was replaced by bmm?

Contributor Author

I think that you are right. @NouamaneTazi, I feel that we should use baddbmm instead, no?

Member

It was replaced by bmm because 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS is the only configuration that passed the test test_batch_generation_padd.

Btw, I just edited the comment you referenced; I had linked the wrong test before. Here's the correct link: https://github.com/younesbelkada/transformers/blob/eb86c4369d465d6181a865e43fd198b28c0b7a02/tests/models/bloom/test_modeling_bloom.py#L450

Member

And yes, @younesbelkada, we could go back to using baddbmm to keep the same operation used during training, as mentioned by Thomas.

Contributor Author

Ah okay, got it now, sorry! @ydshieh @NouamaneTazi, I just updated the comments to detail the reason why we were using bmm.

Collaborator

Thanks for pointing me to the correct place 👍

Contributor Author

Will open a small PR and let you know, @NouamaneTazi.

Contributor Author

Created a PR: #18175
Let's move the discussion there, @ydshieh @NouamaneTazi!

viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* fix tolerance for a bloom slow test

* enhance alibi padding

- get rid of for loops
- deals better with padded batched input
- avoid useless cpu/gpu communication when creating alibi

Co-authored-by: justheuristic <justheuristic@gmail.com>

* optimize attention mask

* fix scaled softmax limit values

* optimize building alibi tensor

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix attention_mask shape when it's None

* minor fixes

- fix docstring + arg names

* remove colons in docstring

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply suggestion

* remove unused arg

* refactor a bit

- use [:, None] for consistency

* refactor attention block

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>

* quick fixes

* first attempt

* refactor attention block and fix all tests except "test_simple_generation"

- added comments to better explain attention block

* remove debug lines and add TODO comment

* change `torch.bmm` to `torch.baddbmm`
- fixes `test_simple_generation` but breaks `test_batch_generation_padd`

* styling

* all tests are passing now
- use `bmm`
- add explanation for `allow_fp16_reduced_precision_reduction`

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* styling

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix support for accelerate

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove attn softmax in fp32

* refactor comments

* refactor a bit

- remove warning message
- remove print on test

* refer to pytorch t5

* change the slow tests

- do the tests in fp32
- remove some comments
- keep large comments

* update expected output for `test_simple_generation`
- we now test using fp32

* make style + change comments a bit

* fix dtype padd test

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@patrickvonplaten deleted the bloom-enhance-alibi branch August 23, 2022 15:30