Skip to content
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d6b7323
Fix import path for streamers module
yafshar Nov 1, 2024
eadc356
Fix the _prepare_decoder_attention_mask interface
yafshar Nov 1, 2024
d91ea70
Improve the _pad_past_key_values
yafshar Nov 1, 2024
6bf985b
Merge branch 'main' into generation
yafshar Nov 8, 2024
c19ea36
Update the list in place
yafshar Nov 10, 2024
0bb308c
Resolve the merge conflict
yafshar Nov 13, 2024
b2542a9
Merge branch 'main' into generation
yafshar Nov 19, 2024
062eb4f
Merge branch 'main' into generation
yafshar Nov 21, 2024
f7bca23
Merge branch 'main' into generation
yafshar Nov 21, 2024
8008528
Merge branch 'main' into generation
yafshar Nov 25, 2024
2579ff3
Merge branch 'main' into generation
yafshar Nov 25, 2024
79ff3c1
Merge branch 'main' into generation
yafshar Nov 25, 2024
41482e1
Merge branch 'main' into generation
yafshar Nov 26, 2024
3bfd16b
Merge branch 'main' into generation
yafshar Nov 26, 2024
ed4d3d8
Merge branch 'main' into generation
yafshar Nov 26, 2024
0e4fcef
Merge branch 'main' into generation
yafshar Nov 28, 2024
31e9d24
Merge branch 'main' into generation
yafshar Dec 2, 2024
d9e7fd1
Merge branch 'main' into generation
yafshar Dec 2, 2024
e677077
Merge branch 'main' into generation
yafshar Dec 2, 2024
e933149
Merge branch 'main' into generation
yafshar Dec 2, 2024
5a281f5
Merge branch 'main' into generation
yafshar Dec 3, 2024
1c301a0
Merge branch 'main' into generation
yafshar Dec 3, 2024
de0db36
Merge branch 'main' into generation
yafshar Dec 3, 2024
0266d22
Merge branch 'main' into generation
yafshar Dec 3, 2024
95dc6eb
Merge branch 'main' into generation
yafshar Dec 4, 2024
64d9f16
Merge branch 'main' into generation
yafshar Dec 4, 2024
7192cf8
Merge branch 'main' into generation
yafshar Dec 4, 2024
1ebcd3d
Merge branch 'main' into generation
yafshar Dec 5, 2024
9e8cf27
Merge branch 'main' into generation
yafshar Dec 6, 2024
3c37e8b
Merge branch 'main' into generation
yafshar Dec 6, 2024
5f8ac1b
Merge branch 'main' into generation
yafshar Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers.streamers import BaseStreamer
from transformers.generation.streamers import BaseStreamer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .candidate_generator import GaudiCandidateGenerator
Expand Down Expand Up @@ -178,12 +178,12 @@ def _prepare_decoder_attention_mask(
self,
max_steps: int, # current stopping criteria
batch_size: int,
pad_token_id: int,
device: str,
dtype: str = bool,
device: Union[str, torch.device],
dtype: torch.dtype = torch.bool,
) -> torch.Tensor:
x = torch.zeros((batch_size, max_steps), device=device, dtype=dtype)
return x.index_fill(1, torch.tensor(0), 1) # First the position with pad_token_id
decoder_attention_mask = torch.zeros((batch_size, max_steps), device=device, dtype=dtype)
index = torch.tensor(0, device=device)
return decoder_attention_mask.index_fill(1, index, 1) # First position with 1

def _prepare_decoder_input_ids_for_generation(
self,
Expand Down Expand Up @@ -337,32 +337,37 @@ def _expand_dict_for_generation(dict_to_expand):
return input_ids, model_kwargs

def _pad_past_key_values(self, model_kwargs):
# Early return if no past key values to pad
past_key_values = model_kwargs.get("past_key_values")
if not past_key_values:
return

# Determine if the model is MQA or not
is_mqa_model = model_kwargs.get("mqa_model", False)
lazy_mode = model_kwargs.get("lazy_mode", False)
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
kv_cache_len = model_kwargs.get("kv_cache_len", 0)
if model_kwargs["past_key_values"]:
if model_kwargs.get("mqa_model", False):
for i in range(len(model_kwargs["past_key_values"])): # layer
if (
torch.is_tensor(model_kwargs["past_key_values"][i])
and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
)
if model_kwargs.get("lazy_mode", False):
kv_cache_len_pad_amount = kv_cache_len - pad_amount

# For MQA models, past_key_values is a tensor
if is_mqa_model:
for i, layer in enumerate(past_key_values): # Iterate over layers
if torch.is_tensor(layer) and layer.shape[-2] == kv_cache_len_pad_amount:
# tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
past_key_values[i] = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount))
Comment thread
yafshar marked this conversation as resolved.
# Mark step if lazy mode is enabled
if lazy_mode:
self.htcore_generation.mark_step()
# For Non-MQA models, the past_key_values is a list of lists (k and v)
else:
for i, layer in enumerate(past_key_values): # Iterate over layers
for j, k_or_v in enumerate(layer): # Iterate over k and v
if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
# tensor(batch_size, n_heads, kv_cache_len, head_dim)
past_key_values[i][j] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
# Mark step if lazy mode is enabled
Comment thread
yafshar marked this conversation as resolved.
if lazy_mode:
self.htcore_generation.mark_step()
else:
for i in range(len(model_kwargs["past_key_values"])): # layer
for j in range(len(model_kwargs["past_key_values"][i])): # k or v
if (
torch.is_tensor(model_kwargs["past_key_values"][i][j])
and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, n_heads, kv_cache_len, head_dim)
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
)
if model_kwargs.get("lazy_mode", False):
self.htcore_generation.mark_step()

def _remove_past_key_values(self, model_kwargs):
if model_kwargs["past_key_values"]:
Expand Down Expand Up @@ -1164,7 +1169,6 @@ def generate(
model_kwargs["decoder_attention_mask"] = self._prepare_decoder_attention_mask(
max_length,
inputs_tensor.shape[0],
generation_config.pad_token_id,
inputs_tensor.device,
)

Expand Down