
add ONNX support for BLOOM #17961

Merged
sgugger merged 7 commits into huggingface:main from NouamaneTazi:main
Jul 1, 2022

Conversation

@NouamaneTazi
Member

What does this PR do?

add ONNX support for BLOOM
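
For context, ONNX support like this is typically exercised through the transformers.onnx export CLI. A minimal usage sketch (the output directory is made up here, and --feature=causal-lm assumes that feature is registered for BLOOM, which is what this PR is expected to enable):

python -m transformers.onnx --model=bigscience/bloom-350m --feature=causal-lm onnx/bloom/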

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@michaelbenayoun

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 30, 2022

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada
Contributor

Since you told me offline that the slow tests were passing (under torch 1.11.0), this looks good to me! Thanks for working on that 🔥

Member

@lewtun lewtun left a comment

Thanks for adding the Bloom config @NouamaneTazi - a very clean PR 💮 !

If the slow tests pass, this PR looks good to me. Let's wait for approval from @LysandreJik or @sgugger before merging this

Edit: I see the CI is now failing for an unrelated issue. I've re-run it, but if it comes up red again, I suggest rebasing on main and pushing again

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
Member

Nice!

):
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
Member

Hehe @michaelbenayoun we should fix this sometime :)

}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-350m"),
Member

Have you checked that the slow tests pass for this checkpoint? You can run:

RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "bloom"

Member Author

Thank you for reminding me. All tests are passing now 🙂

Member

@michaelbenayoun michaelbenayoun left a comment

Great, thanks for handling this!

seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (
-    torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+    (seq_ids[None, None, :] <= seq_ids[None, :, None])
Member

Don't know if this is relevant, but the original implementation outputs a tensor of rank 2 and your change outputs a tensor of rank 3. It should not be a big deal since we reshape it afterwards, but I just wanted to point this out.
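
For illustration, a quick standalone check of the shape difference being discussed (this snippet is not part of the PR diff; max_positions=4 is arbitrary):

import torch

max_positions = 4
seq_ids = torch.arange(max_positions)

# original implementation: lower-triangular boolean matrix, rank 2
mask_2d = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
print(mask_2d.shape)  # torch.Size([4, 4])

# changed version: broadcast comparison, rank 3 (extra leading dimension)
mask_3d = seq_ids[None, None, :] <= seq_ids[None, :, None]
print(mask_3d.shape)  # torch.Size([1, 4, 4])

# both encode the same causal pattern
assert torch.equal(mask_2d, mask_3d[0])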

Member

Do we keep it like that?

Member Author

Fixed :) Thank you for pointing it out 🤗

@sgugger
Collaborator

sgugger commented Jun 30, 2022

I'm not too sure about the changes in modeling_bloom.py. It looks like not leveraging the bool type and converting to int32 will hurt performance. WDYT @younesbelkada?

@michaelbenayoun
Member

I think the changes in modeling_bloom.py come from the fact that boolean tensors cannot be added in ONNX (not 100% sure). Two suggestions then:

I think that the first solution is both faster and more aligned with the original implementation.
WDYT?

@younesbelkada
Contributor

@sgugger I don't think this will hurt performance in terms of logits, since the slow tests pass, but it might indeed hurt inference time for long and/or batched sequences. We'd need to benchmark that to be sure, though.

@sgugger
Collaborator

sgugger commented Jun 30, 2022

@michaelbenayoun I think option 1 sounds good, yes!
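
For illustration, a minimal standalone sketch of the approach that was eventually merged (per the commit "use logical_or instead of `+` for onnx support"); the tensors here are made up for the example and are not the actual BLOOM masks:

import torch

a = torch.tril(torch.ones((4, 4), dtype=torch.bool))  # e.g. a causal mask
b = torch.eye(4, dtype=torch.bool)                     # some other boolean mask

# `a + b` relies on addition of boolean tensors, which is what reportedly
# trips up the ONNX export; torch.logical_or expresses the same union
# explicitly and keeps the result as torch.bool.
combined = torch.logical_or(a, b)
assert combined.dtype == torch.bool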

Member

@michaelbenayoun michaelbenayoun left a comment

LGTM

seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (
-    torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+    (seq_ids[None, None, :] <= seq_ids[None, :, None])
Member

Do we keep it like that?

@michaelbenayoun
Member

Also make sure all the tests pass before merging.

@NouamaneTazi
Member Author

All tests for tests/onnx/test_onnx_v2.py -k "bloom" and tests/models/bloom are passing.
Here are the ones that are skipped (which is fine according to @younesbelkada)

================================================================================= short test summary info =================================================================================
SKIPPED [1] tests/test_modeling_common.py:2006: test is PT+FLAX test
SKIPPED [1] tests/test_modeling_common.py:1934: test is PT+FLAX test
SKIPPED [1] tests/test_modeling_common.py:1758: test is PT+TF test
SKIPPED [1] tests/test_tokenization_common.py:1960: This test is only for slow tokenizers
SKIPPED [1] tests/test_tokenization_common.py:2189: test is PT+TF test
================================================================= 159 passed, 5 skipped, 35 warnings in 449.50s (0:07:29)

@sgugger
Collaborator

sgugger commented Jul 1, 2022

There is a difference between a copy in BLOOM and the original in GPT-2, which is why the CI is failing. Make sure to run `make fix-copies` or remove the `Copied from` comment.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Collaborator

@sgugger sgugger left a comment

Thanks a lot!

@sgugger sgugger merged commit b68d408 into huggingface:main Jul 1, 2022
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* add onnx support for BLOOM

* use TYPE_CHECKING for type annotations

* fix past_shape for bloom (different from gpt2)

* use logical_or instead of `+` for onnx support

* bigger `atol_for_validation` for larger bloom models (see the sketch after this commit list)

* copied -> taken because it's no longer an exact copy

* remove "copied from" comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

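The `atol_for_validation` bump mentioned in the commit list above refers to the tolerance used when the exported ONNX model's outputs are compared against the PyTorch ones during validation. A hedged sketch of how an OnnxConfig subclass can loosen it (the class name, input axes, and the 1e-3 value here are illustrative, not taken from this PR):

from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class ExampleOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # dynamic axes for a decoder-only text model
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # larger checkpoints accumulate more numerical drift between the
        # PyTorch run and the exported ONNX run, so the comparison
        # tolerance is loosened from the library default
        return 1e-3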