diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 97fa46f473..9e05083caa 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -143,6 +143,11 @@ def setup_env(args): os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0") os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true") + if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache and args.bucket_internal: + # Based upon above conditions and below env variable, + # we can call HPU graphs clear_inputs(). + os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1") + # Tweak generation so that it runs faster on Gaudi from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index fd50c00962..50e76b7d75 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -264,6 +264,29 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs + def _pad_past_key_values(self, model_kwargs): + pad_amount = model_kwargs.get("kv_cache_pad_len", 0) + if model_kwargs["past_key_values"]: + for i in range(len(model_kwargs["past_key_values"])): + for j in range(len(model_kwargs["past_key_values"][i])): + if torch.is_tensor(model_kwargs["past_key_values"][i][j]): + model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad( + model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount) + ) + if model_kwargs.get("lazy_mode", False): + self.htcore_generation.mark_step() + + def _remove_past_key_values(self, model_kwargs): + if model_kwargs["past_key_values"]: + for i in range(len(model_kwargs["past_key_values"])): + for j in range(len(model_kwargs["past_key_values"][i])): + if torch.is_tensor(model_kwargs["past_key_values"][i][j]): + t = model_kwargs["past_key_values"][i][j] + del t + model_kwargs["past_key_values"][i][j] = None + del model_kwargs["past_key_values"] + model_kwargs["past_key_values"] = None + def _update_model_kwargs_for_generation( self, outputs: ModelOutput, @@ -278,10 +301,11 @@ def _update_model_kwargs_for_generation( """ # mark to identify starting from second token model_kwargs["first_token"] = False - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) + if not model_kwargs.get("pad_done", False): + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state @@ -745,6 +769,9 @@ def generate( ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size + # Below condition checked explicitly since llama supports bucket_internal even without reuse_cache + if generation_config.bucket_internal: + assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" if generation_config.reuse_cache: assert self.config.model_type in [ "llama", @@ -878,6 +905,7 @@ def generate( model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16 # determine whether limit_hpu_graphs needs to be used + model_kwargs["use_hpu_graphs"] = hpu_graphs model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs # prepare for allocate kv cache @@ -899,7 +927,9 @@ def generate( unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens ) - model_kwargs["kv_cache_len"] = calculated_max_length + if generation_config.use_cache: + model_kwargs["kv_cache_len"] = calculated_max_length + model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens if self.config.model_type in ["llama", "falcon", "mistral"]: if self.config.max_position_embeddings < calculated_max_length: @@ -1618,6 +1648,8 @@ def _greedy_search( cur_len = token_idx.item() time_to_first_token_done = False + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() @@ -1630,7 +1662,6 @@ def _greedy_search( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1737,6 +1768,15 @@ def _greedy_search( ) this_peer_finished = unfinished_sequences.max() == 0 + if ( + not model_kwargs.get("pad_done", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Pad the returned pask key values tensors from prefill phase forward run to maximum length + # before starting the decode phase. + self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True hb_profer.step() if hb_gen_time is not None: if not time_to_first_token_done: @@ -1746,6 +1786,17 @@ def _greedy_search( torch_hpu.synchronize() hb_gen_time.step() + if ( + model_kwargs.get("use_hpu_graphs", False) + and model_kwargs.get("limit_hpu_graphs", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) + hb_profer.stop() if streamer is not None: streamer.end() @@ -2016,6 +2067,8 @@ def _sample( # auto-regressive generation time_to_first_token_done = False + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): if lazy_mode: self.htcore_generation.mark_step() @@ -2028,7 +2081,6 @@ def _sample( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2130,6 +2182,16 @@ def _sample( ) this_peer_finished = unfinished_sequences.max() == 0 + if ( + not model_kwargs.get("pad_done", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Pad the returned pask key values tensors from prefill phase forward run to maximum length + # before starting the decode phase. + self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True + hb_profer.step() if hb_gen_time is not None: if not time_to_first_token_done: @@ -2139,6 +2201,17 @@ def _sample( torch_hpu.synchronize() hb_gen_time.step() + if ( + model_kwargs.get("use_hpu_graphs", False) + and model_kwargs.get("limit_hpu_graphs", False) + and not model_kwargs.get("reuse_cache", False) + and bucket_internal + ): + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) + hb_profer.stop() if streamer is not None: streamer.end() diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index e19d1cc9a8..2199ee0c04 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -203,6 +203,8 @@ def adapt_transformers_to_gaudi(): GaudiGenerationMixin.update_model_kwargs_for_bucketing ) transformers.generation.GenerationMixin._get_hpu_graphs_kwargs = GaudiGenerationMixin._get_hpu_graphs_kwargs + transformers.generation.GenerationMixin._pad_past_key_values = GaudiGenerationMixin._pad_past_key_values + transformers.generation.GenerationMixin._remove_past_key_values = GaudiGenerationMixin._remove_past_key_values transformers.generation.GenerationMixin._expand_inputs_for_generation = staticmethod( GaudiGenerationMixin._expand_inputs_for_generation ) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index e609fa1b2b..cde727a608 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -393,7 +393,8 @@ def pre_attn_forward( past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) - past_key_value = (past_key, past_value) + # Return list instead of tuple + past_key_value = [past_key, past_value] if ( token_idx is not None and num_virtual_tokens is not None @@ -479,6 +480,11 @@ def pre_attn_forward( if not output_attentions: attn_weights = None + if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1: + # Return only past key value shapes and not the tensors during decode phase (q len is 1) + # to avoid making past key values as persistent output tensors of HPU graphs. + past_key_value = (past_key_value[0].shape, past_key_value[1].shape) + return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): @@ -825,6 +831,7 @@ def forward( use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, + None, ) else: layer_outputs = decoder_layer( @@ -1003,6 +1010,7 @@ def prepare_inputs_for_generation( past_length = 0 reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -1041,8 +1049,9 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - elif reuse_cache and token_idx is not None: - # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx]