fix BLOOM ONNX config #19573
Conversation
- `value` params have `seq_len` as their 2nd axis, as opposed to other models, which have it as the 3rd (see the sketch below)
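For concreteness, here is a minimal sketch of the cache layouts in question, using dummy tensors. The BLOOM shapes follow the description in this PR; the GPT-2-style layout shown for comparison is an assumption of the sketch, not something stated here.

```python
import torch

batch, num_heads, head_dim, past_len = 2, 4, 8, 5

# BLOOM-style cache, as described in this PR:
past_key = torch.randn(batch * num_heads, head_dim, past_len)    # seq_len on the 3rd axis (index 2)
past_value = torch.randn(batch * num_heads, past_len, head_dim)  # seq_len on the 2nd axis (index 1)

# Assumed GPT-2-style cache, for comparison: seq_len on the 3rd axis for both tensors.
gpt2_past_key = torch.randn(batch, num_heads, past_len, head_dim)
gpt2_past_value = torch.randn(batch, num_heads, past_len, head_dim)
```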
**ydshieh** left a comment:
LGTM, but I'll leave a few questions for learning purposes. Thank you for making bloom even more blooming, @NouamaneTazi!
```diff
  inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
- inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+ if inverted_values_shape:
+     inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
```
@lewtun Do you know how this string `past_sequence + sequence` is used? Is there some syntax we need to follow to make ONNX things work? For example, what happens if I write `past_sequence + present_sequence`? Or even `0: "batch_dim"`?
I have ONNX knowledge only of measure 0 😢 Is there a guide to understanding this?
Also, from
https://github.com/huggingface/transformers/blob/c8315e3bb0d612993983f401acf44b447e9280d3/src/transformers/onnx/config.py#507
(or just from what @NouamaneTazi did in this PR), it looks like we need to specify dynamic axes not only for the model inputs, but also for all intermediate layers. Is this correct?
@ydshieh these strings refer to the names one assigns to input/output nodes in the ONNX graph. It is just for readability, so changing the string has no effect on the API of `transformers.onnx`.
See e.g. the PyTorch docs for more details: https://pytorch.org/docs/stable/onnx.html#example-alexnet-from-pytorch-to-onnx
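For illustration, here is a minimal sketch (not from this PR) of where such axis-name strings end up, assuming a toy `torch.nn.Linear` model and the file name `linear.onnx`:

```python
import torch

model = torch.nn.Linear(16, 4)
dummy_input = torch.randn(2, 16)  # (batch, features)

torch.onnx.export(
    model,
    dummy_input,
    "linear.onnx",
    input_names=["input"],
    output_names=["logits"],
    # The axis names below only label symbolic dimensions in the graph;
    # "batch" could just as well be "batch_dim".
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
)
```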
Regarding your question about intermediate layers, these names are inferred from the ONNX ops generated in the export, so we never have to define them manually. For example, here are the names associated with BERT (directly from the Hub 🔥): https://netron.app/?url=https://huggingface.co/cmarkea/distilcamembert-base-ner/blob/main/model.onnx
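One quick way to see this is to inspect the exported file from the sketch above (a hedged example; the outputs shown in the comments are typical, not guaranteed):

```python
import onnx

onnx_model = onnx.load("linear.onnx")  # file produced by the previous sketch
print([inp.name for inp in onnx_model.graph.input])      # ['input']  (from input_names)
print([out.name for out in onnx_model.graph.output])     # ['logits'] (from output_names)
print([node.op_type for node in onnx_model.graph.node])  # e.g. ['Gemm'] -- auto-generated ops
```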
**lewtun** left a comment:
Thanks a lot for fixing this bug, @NouamaneTazi 💪!
I've left a few nits, but otherwise this LGTM
Co-authored-by: lewtun <[email protected]>
**lewtun** left a comment:
Thanks for iterating, @NouamaneTazi!
This PR LGTM, so gently pinging @sgugger for final approval :)
**sgugger** left a comment:
Thanks for fixing!
Fixes dynamic axes for `BloomOnnxConfig`. After this PR (#18344), if `use_past` is used, the dynamic axes for BLOOM's `value` tensors should be `{0: 'batch', 1: 'past_sequence + sequence'}` rather than `{0: 'batch', 2: 'past_sequence + sequence'}` (see the sketch below). This should also fix the failing tests for BLOOM's ONNX export.
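To make the fix concrete, here is a simplified, illustrative version of the loop body quoted in the review above (the two-layer loop and variable values are assumptions of this sketch, not code from the repo):

```python
inputs_or_outputs = {}
name = "past_key_values"
inverted_values_shape = True  # True for BLOOM, whose value cache has seq_len on axis 1

for i in range(2):  # pretend the model has 2 layers
    inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
    if inverted_values_shape:
        inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
    else:
        inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}

# inputs_or_outputs["past_key_values.0.value"] -> {0: "batch", 1: "past_sequence + sequence"}
```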
(Tested using `RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "bloom" -s -x`.)

cc @lewtun @ydshieh