Conversation
The documentation is not available anymore as the PR was closed or merged.
As you told me offline that the slow tests were passing (under torch 1.11.0), this looks good to me! Thanks for working on that 🔥
Thanks for adding the Bloom config @NouamaneTazi - a very clean PR 💮!
If the slow tests pass, this PR looks good to me. Let's wait for approval from @LysandreJik or @sgugger before merging.
Edit: I see the CI is now failing for an unrelated issue. I've re-run it, but if it comes up red again, I suggest rebasing on main and pushing again.
```python
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
```

```python
# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
```
Hehe @michaelbenayoun we should fix this sometime :)
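For context, the pattern being discussed is defaulting `pad_token_id` when a model config does not define one. A minimal sketch of the idea, assuming an EOS fallback (the real OnnxConfig may pick a different value):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/bloom-350m")

# Assumption: fall back to the EOS token id when no pad token is set.
# The actual library code may choose a different default.
if not getattr(config, "pad_token_id", None):
    config.pad_token_id = config.eos_token_id
```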
```python
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
    ("bloom", "bigscience/bloom-350m"),
```
Have you checked that the slow tests pass for this checkpoint? You can run:

```bash
RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "bloom"
```
Thank you for reminding me. All tests are passing now 🙂
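As a usage sketch, the registered checkpoint can then be exported through the Python API; the `BloomOnnxConfig` import path below is an assumption based on where this PR adds the class:

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.bloom import BloomOnnxConfig  # assumed location
from transformers.onnx import export

ckpt = "bigscience/bloom-350m"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt)

onnx_config = BloomOnnxConfig(model.config)

# export returns the matched ONNX input/output names, useful for validation.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("bloom.onnx")
)
```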
michaelbenayoun left a comment
Great, thanks for handling this!
```diff
  seq_ids = torch.arange(max_positions, device=input.device)
  causal_mask = (
-     torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+     (seq_ids[None, None, :] <= seq_ids[None, :, None])
```
Don't know if this is relevant, but the original implementation outputs a tensor of rank 2, and your change outputs a tensor of rank 3. It should not be a big deal since we reshape it afterwards, but I just wanted to point this out.
Do we keep it like that?
Fixed :) Thank you for the notice 🤗
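To make the rank point concrete, here is a small sketch of the two constructions (the size 4 is arbitrary):

```python
import torch

max_positions = 4
seq_ids = torch.arange(max_positions)

# Original construction: torch.tril on a 2-D matrix -> rank 2.
tril_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
print(tril_mask.shape)  # torch.Size([4, 4])

# Broadcast comparison with the extra None axes -> rank 3.
cmp_mask = seq_ids[None, None, :] <= seq_ids[None, :, None]
print(cmp_mask.shape)  # torch.Size([1, 4, 4])

# Both agree once the leading singleton dimension is dropped.
assert torch.equal(tril_mask, cmp_mask[0])
```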
I'm not too sure about the changes in the modeling code.
I think the changes in the modeling code can go two ways. I think that the first solution is both faster and more aligned with the original implementation.
@sgugger I do not think this will hurt performance in terms of logits, since the slow tests are passing, but it might indeed hurt inference-time performance for large and/or batched sequences. We would need to benchmark that to be sure, though.
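If anyone wants to measure the concern above, a rough micro-benchmark of the two mask constructions could look like this (a sketch only; numbers will vary by device and sequence length):

```python
import timeit

import torch

max_positions = 2048
seq_ids = torch.arange(max_positions)

def tril_mask():
    return torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))

def cmp_mask():
    return seq_ids[None, None, :] <= seq_ids[None, :, None]

# Wall-clock time of building the causal mask 100 times each way.
print("tril:", timeit.timeit(tril_mask, number=100))
print("cmp :", timeit.timeit(cmp_mask, number=100))
```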
@michaelbenayoun I think option 1 sounds good, yes!
Also make sure all the tests pass before merging.
All tests for BLOOM are passing.
There is a difference between a copy in BLOOM and the original in GPT-2, which is why the CI is failing. Make sure to run `make fix-copies`.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* add onnx support for BLOOM
* use TYPE_CHECKING for type annotations
* fix past_shape for bloom (different from gpt2)
* use logical_or instead of `+` for onnx support
* bigger `atol_for_validation` for larger bloom models
* copied -> taken because it's no longer an exact copy
* remove "copied from" comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
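One of the commit bullets above replaces `+` with `torch.logical_or` when combining boolean masks for ONNX export. A minimal illustration of the explicit spelling (the mask names are hypothetical):

```python
import torch

# Hypothetical masks; in the model these come from the causal structure
# and from padding, respectively.
causal_mask = torch.tril(torch.ones((4, 4), dtype=torch.bool))
padding_mask = torch.zeros((4, 4), dtype=torch.bool)

# torch.logical_or is the explicit, export-friendly way to combine bool masks.
combined = torch.logical_or(causal_mask, padding_mask)
print(combined.dtype, combined.shape)  # torch.bool torch.Size([4, 4])
```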
What does this PR do?
Add ONNX support for BLOOM.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@michaelbenayoun