diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index b963270fdd..728da18ea8 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -215,6 +215,8 @@ def _update_model_kwargs_for_generation( if token_idx is not None: token_idx.add_(1) + if "token_idx_cpu" in model_kwargs: + model_kwargs["token_idx_cpu"] += 1 return model_kwargs @@ -576,6 +578,7 @@ def generate( # token_idx is the current index in the generation process, it is incremented each time a new token is generated token_idx = inputs_tensor.shape[-1] model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device) + model_kwargs["token_idx_cpu"] = token_idx if generation_config.max_new_tokens is None: generation_config.max_new_tokens = generation_config.max_length - token_idx inputs_tensor = torch.nn.functional.pad( @@ -670,6 +673,7 @@ def generate( model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16 # determine whether limit_hpu_graphs needs to be used + model_kwargs["use_hpu_graphs"] = hpu_graphs model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs # prepare for allocate kv cache @@ -1333,8 +1337,9 @@ def greedy_search( hb_profer.start() this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs.get("bucket_size", -1) + bucket_internal = model_kwargs["bucket_internal"] reduce_recompile = model_kwargs.get("reduce_recompile", False) - prev_idx = None # avoiding calculate cache_idx when its value is not changing + prev_idx = -1 # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs.get("bucket_internal", None) prompt_len = input_ids.shape[-1] @@ -1362,23 +1367,12 @@ def greedy_search( if this_peer_finished_flag.item() == 0.0: break - if bucket_size > 0: - if not bucket_internal: - # it will not have been padded if bucket_size > 0 - params = next(inc) - input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( - params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile - ) - else: - # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. - if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: - idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") - if idx != prev_idx: - cache_idx = (idx.item() + 1) * bucket_size - model_kwargs["cache_idx"] = cache_idx - prev_idx = idx - else: - model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] + if bucket_size > 0 and not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1453,6 +1447,16 @@ def greedy_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) + if bucket_size > 0 and bucket_internal: + # Calculate slice idx for kv cache during the decode phase. + # Breaking down the kv cache in the attention block helps to reduce computation time. + if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size + if prev_idx != idx: + model_kwargs["cache_idx"] = (idx + 1) * bucket_size + prev_idx = idx + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index f811bc38e5..de0d5aa0f4 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -291,7 +291,8 @@ def pre_attn_forward( if cache_idx is not None and q_len == 1: key_states = key_states[:, :, :cache_idx, :] value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] kv_seq_len = key_states.shape[-2] if use_cache: