diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index 4f973d93ae48..b64d9f044a98 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -152,7 +152,6 @@ def __init__( class BloomOnnxConfig(OnnxConfigWithPast): - torch_onnx_minimum_version = version.parse("1.12") def __init__( @@ -171,7 +170,8 @@ def __init__( def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") + # BLOOM stores values on dynamic axis 1. For more details see: https://github.com/huggingface/transformers/pull/18344 + self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True) common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index dc27c0b6924d..5a1c3e6eede5 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -486,7 +486,9 @@ def generate_dummy_inputs( return common_inputs - def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): + def fill_with_past_key_values_( + self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False + ): """ Fill the input_or_outputs mapping with past_key_values dynamic axes considering. @@ -494,6 +496,8 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int inputs_or_outputs: The mapping to fill. direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the output mapping, this is important for axes naming. 
+ inverted_values_shape: + If `True`, store values on dynamic axis 1, else on axis 2. """ if direction not in ["inputs", "outputs"]: @@ -502,7 +506,10 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int name = "past_key_values" if direction == "inputs" else "present" for i in range(self.num_layers): inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + if inverted_values_shape: + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"} + else: + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} def _flatten_past_key_values_(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.key"] = t[0]