fix BLOOM ONNX config #19573
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -486,14 +486,18 @@ def generate_dummy_inputs( | |
|
|
||
| return common_inputs | ||
|
|
||
| def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): | ||
| def fill_with_past_key_values_( | ||
| self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False | ||
| ): | ||
| """ | ||
| Fill the input_or_outputs mapping with past_key_values dynamic axes considering. | ||
|
|
||
| Args: | ||
| inputs_or_outputs: The mapping to fill. | ||
| direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the | ||
| output mapping, this is important for axes naming. | ||
| inverted_values_shape: | ||
|
||
| If `True`, store values on dynamic axis 1, else on axis 2. | ||
|
|
||
| """ | ||
| if direction not in ["inputs", "outputs"]: | ||
|
|
@@ -502,7 +506,10 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int | |
| name = "past_key_values" if direction == "inputs" else "present" | ||
| for i in range(self.num_layers): | ||
| inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} | ||
| inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} | ||
| if inverted_values_shape: | ||
| inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"} | ||
|
Collaborator
@lewtun Do you know how this string is used? I have ONNX knowledge only of measure 0 😢 Is there a guide to understand this?
Collaborator
Also, from
Member
@ydshieh these strings refer to the names one assigns to input/output nodes in the ONNX graph. It is just for readability, so changing the string has no effect on the API of
See e.g. the PyTorch docs for more details: https://pytorch.org/docs/stable/onnx.html#example-alexnet-from-pytorch-to-onnx
Regarding your question about intermediate layers, these names are inferred from the ONNX ops generated in the export, so we never have to define them manually. For example, here are the names associated with BERT (directly from the Hub 🔥): https://netron.app/?url=https://huggingface.co/cmarkea/distilcamembert-base-ner/blob/main/model.onnx |
||
| else: | ||
| inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} | ||
|
|
||
| def _flatten_past_key_values_(self, flattened_output, name, idx, t): | ||
| flattened_output[f"{name}.{idx}.key"] = t[0] | ||
|
|
||
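For readers who want to see what the changed method produces, here is a minimal standalone sketch of the logic in the diff above. It is not the library's code: the function name `fill_with_past_key_values`, the explicit `num_layers` argument (standing in for `self.num_layers`), and the error message are assumptions made so the example runs on its own.

```python
from typing import Dict


def fill_with_past_key_values(
    axes: Dict[str, Dict[int, str]],
    direction: str,
    num_layers: int,  # assumption: passed explicitly instead of self.num_layers
    inverted_values_shape: bool = False,
) -> None:
    """Fill `axes` in place with dynamic-axis names for the key/value cache."""
    if direction not in ("inputs", "outputs"):
        # The diff cuts off after this check; raising here is an assumption.
        raise ValueError(f'direction must be "inputs" or "outputs", got {direction!r}')

    name = "past_key_values" if direction == "inputs" else "present"
    # Per the docstring in the diff: True puts the dynamic sequence axis of
    # `value` at index 1, otherwise it stays at index 2.
    value_seq_axis = 1 if inverted_values_shape else 2
    for i in range(num_layers):
        axes[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
        axes[f"{name}.{i}.value"] = {0: "batch", value_seq_axis: "past_sequence + sequence"}


inputs: Dict[str, Dict[int, str]] = {}
fill_with_past_key_values(inputs, "inputs", num_layers=2, inverted_values_shape=True)
for axis_name, dynamic_axes in inputs.items():
    print(axis_name, dynamic_axes)
```

With `inverted_values_shape=True`, only the `value` entries change: their dynamic sequence axis moves from index 2 to index 1, while the `key` entries keep the layout they had before the change.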