diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 7b366adacc..9f451256c9 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -294,6 +294,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -334,18 +335,37 @@ def forward( key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - if layer_past is not None: + _, q_len, _ = hidden_states.size() + bucket_internal_decode_stage = cache_idx is not None and q_len == 1 + + if not bucket_internal_decode_stage: + if layer_past is not None: + past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1) + if token_idx is not None: + # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled. + key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1)) + value = past_value.index_add( + 1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1) + ) + else: + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = torch.cat((key, value), dim=-1) if use_cache else None + else: + assert token_idx is not None, "Invalid parameters: token_idx is None at decode stage with bucket_internal" + assert ( + layer_past is not None + ), "Invalid parameters: layer_past is None at decode stage with bucket_internal" + past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1) - if token_idx is not None: - # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled. - key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1)) - value = past_value.index_add( - 1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1) - ) - else: - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = torch.cat((key, value), dim=-1) if use_cache else None + key = past_key.index_copy_(1, token_idx - 1, key) + value = past_value.index_copy_(1, token_idx - 1, value) + present = layer_past + + if bucket_internal_decode_stage: + key = key[:, :cache_idx, :] + value = value[:, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] if not output_attentions and head_mask is None and use_flash_attention: # Difference with the original implementation: there is no need to transpose the key here, @@ -367,6 +387,11 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) + if bucket_internal_decode_stage: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + present = present.shape + outputs = (attn_output, present) if output_attentions: if self.multi_query: @@ -392,6 +417,7 @@ def gaudi_gpt_bigcode_block_forward( flash_attention_recompute: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + cache_idx: Optional[int] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ @@ -413,6 +439,7 @@ def gaudi_gpt_bigcode_block_forward( flash_attention_recompute=flash_attention_recompute, flash_attention_fast_softmax=flash_attention_fast_softmax, flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -475,6 +502,7 @@ def gaudi_gpt_bigcode_model_forward( flash_attention_recompute: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: """ Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -638,6 +666,7 @@ def gaudi_gpt_bigcode_model_forward( flash_attention_recompute=flash_attention_recompute, flash_attention_fast_softmax=flash_attention_fast_softmax, flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -750,6 +779,7 @@ def prepare_inputs_for_generation( "flash_attention_recompute": kwargs.get("flash_attention_recompute", False), "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False), + "cache_idx": kwargs.get("cache_idx", None), } ) return model_inputs @@ -775,6 +805,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -803,6 +834,7 @@ def forward( flash_attention_recompute=flash_attention_recompute, flash_attention_fast_softmax=flash_attention_fast_softmax, flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, ) hidden_states = transformer_outputs[0] diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index ed1a094e47..7e9d5ebcd5 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -27,13 +27,7 @@ ("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354, False), ("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076, True), ("tiiuae/falcon-40b", 1, True, 25.202450111088346, False), - ( - "bigcode/starcoder", - 256, - True, - 6846.575763562658, - False, - ), # TODO: Enable check_output after model bigcode/starcoder is fixed + ("bigcode/starcoder", 256, True, 6846.575763562658, True), ("Salesforce/codegen2-1B", 1, False, 446.4029486883532, False), ("mosaicml/mpt-30b", 1, False, 36.06464336116623, False), ("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782, True), @@ -41,7 +35,7 @@ ("microsoft/phi-2", 1, False, 224.72307766211117, False), ("meta-llama/Meta-Llama-3-8B", 1, True, 129, False), ("meta-llama/Llama-2-7b-hf", 512, True, 12808, False), - ("meta-llama/Llama-2-7b-hf", 512, False, 8711, False), # in some cases like TGI, reuse_cache isnt used + ("meta-llama/Llama-2-7b-hf", 512, False, 8711, False), # in some cases like TGI, reuse_cache isn't used ("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218, False), ("codellama/CodeLlama-34b-hf", 1, True, 32.644, False), ("bigcode/starcoder2-3b", 1, False, 261.07213776344133, True),