diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py
index 23ecc6d92671..a33a6339b14e 100644
--- a/src/transformers/models/bloom/configuration_bloom.py
+++ b/src/transformers/models/bloom/configuration_bloom.py
@@ -214,14 +214,19 @@ def generate_dummy_inputs(
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
                 past_key_values_length = seqlen + 2
-                past_shape = (
-                    batch,
+                head_dim = self._config.hidden_size // self.num_attention_heads
+                past_key_shape = (
+                    batch * self.num_attention_heads,
+                    head_dim,
                     past_key_values_length,
-                    self.num_attention_heads,
-                    self._config.hidden_size // self.num_attention_heads,
+                )
+                past_value_shape = (
+                    batch * self.num_attention_heads,
+                    past_key_values_length,
+                    head_dim,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
                 ]

         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index f6757d541b30..a33054a38351 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -61,8 +61,9 @@ def _make_causal_mask(
     """
     batch_size, target_length = input_ids_shape
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
-    mask[:, past_key_values_length:] = True
-    mask[:, past_key_values_length:].triu_(diagonal=1)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

     if past_key_values_length > 0:
         mask[:, :past_key_values_length] = False
@@ -698,8 +699,7 @@ def forward(
         past_key_values_length = 0
         if past_key_values[0] is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past += past_key_values_length
-
+            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
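
For context on the `configuration_bloom.py` hunk: BLOOM's attention cache fuses the batch and head dimensions and stores keys transposed relative to values, so the dummy `past_key_values` used for ONNX export must have key shape `[batch * num_heads, head_dim, kv_length]` and value shape `[batch * num_heads, kv_length, head_dim]`. This layout is also why `forward` reads the past length from `past_key_values[0][0].shape[2]`. Below is a minimal standalone sketch of the resulting dummy cache, using made-up toy dimensions rather than any real checkpoint's config:

```python
import torch

# Illustrative toy dimensions (not taken from any real BLOOM checkpoint).
batch = 2
num_attention_heads = 4
hidden_size = 32
num_layers = 3
seqlen = 7
past_key_values_length = seqlen + 2  # mirrors the "+ 2" offset in generate_dummy_inputs

head_dim = hidden_size // num_attention_heads

# BLOOM's cache fuses batch and head dimensions, and keys are stored transposed
# relative to values so the attention matmul can consume them directly:
#   key:   [batch * num_heads, head_dim, kv_length]
#   value: [batch * num_heads, kv_length, head_dim]
past_key_shape = (batch * num_attention_heads, head_dim, past_key_values_length)
past_value_shape = (batch * num_attention_heads, past_key_values_length, head_dim)

past_key_values = [
    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(num_layers)
]

# The past length is recovered from the key's last axis, matching the modeling change.
assert past_key_values[0][0].shape == (8, 8, 9)
assert past_key_values[0][1].shape == (8, 9, 8)
assert past_key_values[0][0].shape[2] == past_key_values_length
```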
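
For the `_make_causal_mask` change: because ONNX export does not handle `torch.Tensor.triu` well, the strictly-upper-triangular pattern is rewritten as an element-wise comparison of `torch.arange` indices. The sketch below is not part of the patch; it pulls the old and new constructions into a standalone script (using the out-of-place `triu` for the old variant) and checks that both yield the same boolean mask:

```python
import torch

# Mask entries are True exactly where the key position comes after the query position.
for target_length in (1, 2, 5):
    for past_key_values_length in (0, 3):
        # Old-style construction based on triu(diagonal=1).
        mask_triu = torch.empty(
            (target_length, target_length + past_key_values_length), dtype=torch.bool
        )
        mask_triu[:, past_key_values_length:] = True
        mask_triu[:, past_key_values_length:] = mask_triu[:, past_key_values_length:].triu(diagonal=1)
        if past_key_values_length > 0:
            mask_triu[:, :past_key_values_length] = False

        # ONNX-friendly construction: compare position indices instead of calling triu.
        mask_arange = torch.empty(
            (target_length, target_length + past_key_values_length), dtype=torch.bool
        )
        seq_ids = torch.arange(target_length)
        mask_arange[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
        if past_key_values_length > 0:
            mask_arange[:, :past_key_values_length] = False

        assert torch.equal(mask_triu, mask_arange)
```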