Merged
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/configuration_bloom.py
@@ -152,7 +152,6 @@ def __init__(


class BloomOnnxConfig(OnnxConfigWithPast):

    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
@@ -171,7 +170,8 @@ def __init__(
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            # BLOOM stores the past values with their dynamic sequence axis at index 1 instead of 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
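For a concrete picture of the change, here is a minimal sketch (not taken from the PR) of the mapping `BloomOnnxConfig.inputs` would now produce for an assumed 2-layer model with `use_past=True`: keys keep the variable sequence length on dynamic axis 2, while BLOOM's values carry it on axis 1.

```python
from collections import OrderedDict

# Hypothetical result for num_layers == 2 and use_past == True (layer count
# chosen only for illustration). Note the value entries use axis 1, which is
# what inverted_values_shape=True requests.
expected_inputs = OrderedDict(
    {
        "input_ids": {0: "batch", 1: "sequence"},
        "past_key_values.0.key": {0: "batch", 2: "past_sequence + sequence"},
        "past_key_values.0.value": {0: "batch", 1: "past_sequence + sequence"},
        "past_key_values.1.key": {0: "batch", 2: "past_sequence + sequence"},
        "past_key_values.1.value": {0: "batch", 1: "past_sequence + sequence"},
        "attention_mask": {0: "batch", 1: "past_sequence + sequence"},
    }
)
```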
11 changes: 9 additions & 2 deletions src/transformers/onnx/config.py
@@ -486,14 +486,18 @@ def generate_dummy_inputs(

        return common_inputs

    def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
    def fill_with_past_key_values_(
        self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
    ):
        """
        Fill the input_or_outputs mapping with past_key_values dynamic axes, considering the direction.

        Args:
            inputs_or_outputs: The mapping to fill.
            direction: either "inputs" or "outputs"; specifies whether input_or_outputs is the input or the output
                mapping, which matters for axes naming.
            inverted_values_shape:
                If `True`, store values on dynamic axis 1, else on axis 2.

        """
        if direction not in ["inputs", "outputs"]:
@@ -502,7 +506,10 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int
name = "past_key_values" if direction == "inputs" else "present"
for i in range(self.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
if inverted_values_shape:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
ydshieh (Collaborator) commented:
@lewtun Do you know how this string past_sequence + sequence is used? Is there some syntax we need to follow to make ONNX things work? For example, what happens if I write past_sequence + present_sequence? Or even 0: "batch_dim"?

I have ONNX knowledge only of measure 0 😢 Is there a guide to understand this?

ydshieh (Collaborator) commented:
Also, from https://github.com/huggingface/transformers/blob/c8315e3bb0d612993983f401acf44b447e9280d3/src/transformers/onnx/config.py#507 (or just what @NouamaneTazi did in this PR), it looks like we need to specify dynamic axes not only for the model inputs, but also for all intermediate layers. Is this correct?

lewtun (Member) commented on Oct 13, 2022:
@ydshieh these strings refer to the names one assigns to input/output nodes in the ONNX graph. It is just for readability, so changing the string has no effect on the API of transformers.onnx

See e.g. the PyTorch docs for more details: https://pytorch.org/docs/stable/onnx.html#example-alexnet-from-pytorch-to-onnx

Regarding your question about intermediate layers, these names are inferred from the ONNX ops generated in the export, so we never have to define them manually. For example, here are the names associated with BERT (directly from the Hub 🔥): https://netron.app/?url=https://huggingface.co/cmarkea/distilcamembert-base-ner/blob/main/model.onnx
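To illustrate lewtun's point, here is a minimal, hypothetical export (not part of this PR) showing where such axis-name strings end up: they become the labels in the dynamic_axes argument of torch.onnx.export, so renaming them only changes how the exported graph is annotated.

```python
import torch
from torch import nn

# Toy model and input; the names below are arbitrary labels, not transformers API.
model = nn.Linear(8, 4)
dummy_input = torch.randn(2, 8)

torch.onnx.export(
    model,
    dummy_input,
    "linear.onnx",
    input_names=["input"],
    output_names=["output"],
    # Axis 0 of both tensors is marked dynamic; "batch" is purely a readable label,
    # exactly like "past_sequence + sequence" in the mappings above.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
```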

            else:
                inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}

    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        flattened_output[f"{name}.{idx}.key"] = t[0]
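As a quick way to check the new flag's behaviour in isolation, here is a self-contained sketch that mirrors the method above; num_layers is passed explicitly here (an assumption of this snippet) instead of being read from the model config.

```python
from collections import OrderedDict
from typing import Mapping


def fill_with_past_key_values(
    inputs_or_outputs: Mapping[str, Mapping[int, str]],
    direction: str,
    num_layers: int = 2,
    inverted_values_shape: bool = False,
):
    # "inputs" produces past_key_values.* names, "outputs" produces present.* names.
    if direction not in ["inputs", "outputs"]:
        raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

    name = "past_key_values" if direction == "inputs" else "present"
    for i in range(num_layers):
        inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
        # The only behavioural change in this PR: values may advertise their dynamic
        # sequence axis at index 1 (as BLOOM needs) instead of the default index 2.
        value_axis = 1 if inverted_values_shape else 2
        inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", value_axis: "past_sequence + sequence"}


common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
fill_with_past_key_values(common_inputs, direction="inputs", inverted_values_shape=True)
assert common_inputs["past_key_values.0.value"] == {0: "batch", 1: "past_sequence + sequence"}
```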