From d6b7323669fc7a3ebbfaabedf26cb7ccb5ccb23e Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Fri, 1 Nov 2024 11:02:40 -0700 Subject: [PATCH 1/4] Fix import path for streamers module --- optimum/habana/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 3ca5ae4793..1b09304821 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -70,7 +70,7 @@ if TYPE_CHECKING: from transformers import PreTrainedModel - from transformers.streamers import BaseStreamer + from transformers.generation.streamers import BaseStreamer from transformers.tokenization_utils_base import PreTrainedTokenizerBase from .candidate_generator import GaudiCandidateGenerator From eadc356be54b8a44813038016f38b604b12f797b Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Fri, 1 Nov 2024 11:33:50 -0700 Subject: [PATCH 2/4] Fix the _prepare_decoder_attention_mask interface - Fix the type hint, dtype can not be a str - Fix the device hint - Remove the pad token id arg, the decoder_attention_mask is a binary of 0, and 1 --- optimum/habana/transformers/generation/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1b09304821..4dbba44208 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -167,12 +167,12 @@ def _prepare_decoder_attention_mask( self, max_steps: int, # current stopping criteria batch_size: int, - pad_token_id: int, - device: str, - dtype: str = bool, + device: Union[str, torch.device], + dtype: torch.dtype = torch.bool, ) -> torch.Tensor: - x = torch.zeros((batch_size, max_steps), device=device, dtype=dtype) - return x.index_fill(1, torch.tensor(0), 1) # First the position with pad_token_id + decoder_attention_mask = torch.zeros((batch_size, max_steps), device=device, dtype=dtype) + index = torch.tensor(0, device=device) + return decoder_attention_mask.index_fill(1, index, 1) # First position with 1 def _prepare_decoder_input_ids_for_generation( self, @@ -1123,7 +1123,6 @@ def generate( model_kwargs["decoder_attention_mask"] = self._prepare_decoder_attention_mask( max_length, inputs_tensor.shape[0], - generation_config.pad_token_id, inputs_tensor.device, ) From d91ea707f0b01be69d9c19cc85c137eebe95b624 Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Fri, 1 Nov 2024 12:03:57 -0700 Subject: [PATCH 3/4] Improve the _pad_past_key_values - Added an early return - Extracted is_mqa_model and lazy_mode to avoid repeated dictionary lookups - Used more descriptive variable names and simplified the nested loops for better readability --- .../habana/transformers/generation/utils.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 4dbba44208..eeffe01173 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -326,29 +326,35 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs def _pad_past_key_values(self, model_kwargs): + # Early return if no past key values to pad + past_key_values = model_kwargs.get("past_key_values") + if not past_key_values: + return + + # Determine if the model is MQA or not + is_mqa_model = model_kwargs.get("mqa_model", False) + lazy_mode = model_kwargs.get("lazy_mode", False) pad_amount = model_kwargs.get("kv_cache_pad_len", 0) - if model_kwargs["past_key_values"]: - if model_kwargs.get("mqa_model", False): - for i in range(len(model_kwargs["past_key_values"])): # layer - if torch.is_tensor( - model_kwargs["past_key_values"][i] - ): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked - model_kwargs["past_key_values"][i] = torch.nn.functional.pad( - model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount) - ) - if model_kwargs.get("lazy_mode", False): + + # For MQA models, past_key_values is a tensor + if is_mqa_model: + for layer in past_key_values: # Iterate over layers + if torch.is_tensor(layer): + # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked + layer = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount)) + # Mark step if lazy mode is enabled + if lazy_mode: + self.htcore_generation.mark_step() + # For Non-MQA models, the past_key_values is a list of lists (k and v) + else: + for layer in past_key_values: # Iterate over layers + for k_or_v in layer: # Iterate over k and v + if torch.is_tensor(k_or_v): + # tensor(batch_size, n_heads, kv_cache_len, head_dim) + k_or_v = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount)) + # Mark step if lazy mode is enabled + if lazy_mode: self.htcore_generation.mark_step() - else: - for i in range(len(model_kwargs["past_key_values"])): # layer - for j in range(len(model_kwargs["past_key_values"][i])): # k or v - if torch.is_tensor( - model_kwargs["past_key_values"][i][j] - ): # tensor(batch_size, n_heads, kv_cache_len, head_dim) - model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad( - model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount) - ) - if model_kwargs.get("lazy_mode", False): - self.htcore_generation.mark_step() def _remove_past_key_values(self, model_kwargs): if model_kwargs["past_key_values"]: From c19ea362b4c9afa31e4e26ed7b4f3a4c9937d257 Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Sun, 10 Nov 2024 06:31:00 -0800 Subject: [PATCH 4/4] Update the list in place --- optimum/habana/transformers/generation/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index f69f938c34..3837e5fb53 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -340,20 +340,20 @@ def _pad_past_key_values(self, model_kwargs): # For MQA models, past_key_values is a tensor if is_mqa_model: - for layer in past_key_values: # Iterate over layers + for i, layer in enumerate(past_key_values): # Iterate over layers if torch.is_tensor(layer): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked - layer = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount)) + past_key_values[i] = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount)) # Mark step if lazy mode is enabled if lazy_mode: self.htcore_generation.mark_step() # For Non-MQA models, the past_key_values is a list of lists (k and v) else: - for layer in past_key_values: # Iterate over layers - for k_or_v in layer: # Iterate over k and v + for i, layer in enumerate(past_key_values): # Iterate over layers + for j, k_or_v in enumerate(layer): # Iterate over k and v if torch.is_tensor(k_or_v): # tensor(batch_size, n_heads, kv_cache_len, head_dim) - k_or_v = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount)) + past_key_values[i][j] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount)) # Mark step if lazy mode is enabled if lazy_mode: self.htcore_generation.mark_step()