From 4875b1ea117727c371e8d1da0fc124d3c5b4d870 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 12 Apr 2024 11:31:54 +0300 Subject: [PATCH 1/2] Sampling search UseKV cache till input seq len for prefill phase --- .../habana/transformers/generation/utils.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index fdcf315cad..1b56a37fc2 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -624,8 +624,19 @@ def generate( ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size + # Below condition checked explicitly since llama supports bucket_internal even without reuse_cache if generation_config.bucket_internal: assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" + if generation_config.reuse_cache: + assert self.config.model_type in [ + "llama", + "mistral", + "falcon", + ], "reuse_cache only supported by llama, mistral and falcon at the moment" + if not generation_config.bucket_internal: + assert ( + generation_config.bucket_size <= 0 + ), "please set bucket_internal along with reuse_cache and bucket_size" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -1794,6 +1805,9 @@ def sample( assert "position_ids" not in model_kwargs, "Untested path" sample_first = True + model_kwargs["pad_done"] = False + model_kwargs["lazy_mode"] = lazy_mode + # auto-regressive generation while True: if lazy_mode: @@ -1922,6 +1936,21 @@ def sample( if this_peer_finished and not synced_gpus: break + if not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) \ + and bucket_internal: + # Pad the returned pask key values tensors from prefill phase forward run to maximum length + # before starting the decode phase. + self._pad_past_key_values(model_kwargs) + model_kwargs["pad_done"] = True + + if model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) \ + and not model_kwargs.get("reuse_cache", False) and bucket_internal: + # Clear HPU graphs input tensors of the decode phase after the full generation while loop + print("CLEAR HPU GRAPH INPUTS OF DECODE PHASE") + self.clear_inputs() + # Delete past key value tensors + self._remove_past_key_values(model_kwargs) + hb_profer.stop() if streamer is not None: streamer.end() From f13eea4d1bb3bf178564f0123cb80206173cdc1c Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 12 Apr 2024 11:33:59 +0300 Subject: [PATCH 2/2] Remove redundant line --- optimum/habana/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1b56a37fc2..701fa1890d 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1831,7 +1831,6 @@ def sample( ) # prepare model inputs - model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)