Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def _update_model_kwargs_for_generation(

if token_idx is not None:
token_idx.add_(1)
if "token_idx_cpu" in model_kwargs:
model_kwargs["token_idx_cpu"] += 1

return model_kwargs

Expand Down Expand Up @@ -576,6 +578,7 @@ def generate(
# token_idx is the current index in the generation process, it is incremented each time a new token is generated
token_idx = inputs_tensor.shape[-1]
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
model_kwargs["token_idx_cpu"] = token_idx
if generation_config.max_new_tokens is None:
generation_config.max_new_tokens = generation_config.max_length - token_idx
inputs_tensor = torch.nn.functional.pad(
Expand Down Expand Up @@ -670,6 +673,7 @@ def generate(
model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16

# determine whether limit_hpu_graphs needs to be used
model_kwargs["use_hpu_graphs"] = hpu_graphs
model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs

# prepare for allocate kv cache
Expand Down Expand Up @@ -1333,8 +1337,9 @@ def greedy_search(
hb_profer.start()
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs.get("bucket_size", -1)
bucket_internal = model_kwargs["bucket_internal"]
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prev_idx = None # avoiding calculate cache_idx when its value is not changing
prev_idx = -1 # avoiding calculate cache_idx when its value is not changing
bucket_internal = model_kwargs.get("bucket_internal", None)

prompt_len = input_ids.shape[-1]
Expand Down Expand Up @@ -1362,23 +1367,12 @@ def greedy_search(
if this_peer_finished_flag.item() == 0.0:
break

if bucket_size > 0:
if not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
else:
# Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
if idx != prev_idx:
cache_idx = (idx.item() + 1) * bucket_size
model_kwargs["cache_idx"] = cache_idx
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
if bucket_size > 0 and not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
Expand Down Expand Up @@ -1453,6 +1447,16 @@ def greedy_search(
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if bucket_size > 0 and bucket_internal:
# Calculate slice idx for kv cache during the decode phase.
# Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
if prev_idx != idx:
model_kwargs["cache_idx"] = (idx + 1) * bucket_size
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
cur_len = cur_len + 1

# if eos_token was found in one sentence, set sentence to finished
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def pre_attn_forward(
if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
Expand Down