From 52543174ccf02a8fbae5892e57d7d83d85a273a2 Mon Sep 17 00:00:00 2001 From: xt574chen <158136116+xt574chen@users.noreply.github.com> Date: Sat, 2 Mar 2024 17:15:41 +0800 Subject: [PATCH] [feat] extend bucket_internal to SAMPLE generation mode --- .../habana/transformers/generation/utils.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 5ac7894e6d..80a1ee4e9d 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -746,7 +746,9 @@ def generate( # if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search) if generation_config.static_shapes and generation_config.bucket_size > 0: assert ( - generation_mode == GenerationMode.GREEDY_SEARCH or generation_mode == GenerationMode.BEAM_SEARCH + generation_mode == GenerationMode.GREEDY_SEARCH + or generation_mode == GenerationMode.SAMPLE + or generation_mode == GenerationMode.BEAM_SEARCH ), "generation_config.bucket_size > 0 supported only for greedy mode" if streamer is not None and (generation_config.num_beams > 1): @@ -1763,6 +1765,18 @@ def sample( # Update cur_len in case of static shapes cur_len = token_idx.item() + bucket_size = model_kwargs.get("bucket_size", -1) + prev_idx = -1 # avoiding calculate cache_idx when its value is not changing + bucket_internal = model_kwargs.get("bucket_internal", None) + reduce_recompile = model_kwargs.get("reduce_recompile", False) + + prompt_len = input_ids.shape[-1] + if not bucket_internal: + if bucket_size >= 0: + inc = iter(incrementor(bucket_size, prompt_len)) + if bucket_size > 0: + assert "position_ids" not in model_kwargs, "Untested path" + # auto-regressive generation while True: if lazy_mode: @@ -1778,6 +1792,13 @@ def sample( if this_peer_finished_flag.item() == 0.0: break + if bucket_size > 0 and not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) + # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1849,6 +1870,17 @@ def sample( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) cur_len = cur_len + 1 + if bucket_size > 0 and bucket_internal: + # Calculate slice idx for kv cache during the decode phase. + # Breaking down the kv cache in the attention block helps to reduce computation time. + if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size + if prev_idx != idx: + model_kwargs["cache_idx"] = (idx + 1) * bucket_size + prev_idx = idx + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] + # if eos_token was found in one sentence, set sentence to finished if not ignore_eos and eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul(