diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index fdcf315cad..701fa1890d 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -624,8 +624,19 @@ def generate( ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size + # Below condition checked explicitly since llama supports bucket_internal even without reuse_cache if generation_config.bucket_internal: assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" + if generation_config.reuse_cache: + assert self.config.model_type in [ + "llama", + "mistral", + "falcon", + ], "reuse_cache only supported by llama, mistral and falcon at the moment" + if not generation_config.bucket_internal: + assert ( + generation_config.bucket_size <= 0 + ), "please set bucket_internal along with reuse_cache and bucket_size" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -1794,6 +1805,9 @@ def sample( assert "position_ids" not in model_kwargs, "Untested path" sample_first = True + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode + # auto-regressive generation while True: if lazy_mode: @@ -1817,7 +1831,6 @@ def sample( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1922,6 +1935,21 @@ def sample( if this_peer_finished and not synced_gpus: break + if not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) \ + and bucket_internal: + # Pad the returned past key value tensors from prefill phase forward run to maximum length + # before starting the decode phase. 
+ self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True + + if model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) \ + and not model_kwargs.get("reuse_cache", False) and bucket_internal: + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + print("CLEAR HPU GRAPH INPUTS OF DECODE PHASE") + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) + hb_profer.stop() if streamer is not None: streamer.end()