From 30963f371658c148a8e533a7f0bafe8ebe74cef8 Mon Sep 17 00:00:00 2001 From: Xiaotong Chen Date: Tue, 6 Feb 2024 00:51:13 +0800 Subject: [PATCH 1/4] enable internal kv bucket in llama --- examples/text-generation/README.md | 3 ++ examples/text-generation/run_generation.py | 5 +++ examples/text-generation/utils.py | 1 + .../habana/transformers/generation/utils.py | 45 +++++++++++++------ .../models/llama/modeling_llama.py | 17 +++++++ 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 5732e684a4..b57bf49045 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -236,6 +236,9 @@ python run_generation.py \ `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies no truncation of input prompt in the dataset. Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example `--simulate_dyn_prompt 25,35,45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. Thus the suggested configuration to simulate dynamicity after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30` + +While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size) + ### Running with FP8 Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 14e9712595..743ceca91d 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -186,6 +186,11 @@ def setup_parser(parser): then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \ we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).", ) + parser.add_argument( + "--bucket_internal", + action="store_true", + help="Split kv sequence into buckets in decode phase. It is useful for long new tokens.", + ) parser.add_argument( "--dataset_max_samples", default=-1, diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 9b66de8128..4bd8f27bb5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -329,6 +329,7 @@ def setup_generation_config(args, model, tokenizer): generation_config.use_cache = args.use_kv_cache generation_config.static_shapes = is_optimized generation_config.bucket_size = args.bucket_size if is_optimized else -1 + generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample generation_config.num_beams = args.num_beams generation_config.bad_words_ids = bad_words_ids diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bc8fad5118..50a6c54123 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -589,17 +589,25 @@ def generate( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) - is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and ( - self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH - or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + is_greedy_or_beam_and_bucket = ( + not generation_config.bucket_internal + and generation_config.bucket_size > 0 + and ( + self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH + or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + ) ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 + model_kwargs["bucket_internal"] = generation_config.bucket_internal model_kwargs["reduce_recompile"] = ( generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size - if generation_config.reuse_cache: + if generation_config.bucket_internal: + assert generation_config.bucket_size >= 0, "bucket_internal and bucket_size flags set together" + assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal" + if generation_config.reuse_cache and not generation_config.bucket_internal: assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" if generation_config.static_shapes: @@ -713,6 +721,8 @@ def generate( token_idx, generation_config.kv_cache_fp8, ) + model_kwargs["kv_cache_len"] = calculated_max_length + if self.config.model_type in ["llama"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) @@ -1369,13 +1379,15 @@ def greedy_search( hb_profer.start() this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs["bucket_size"] + bucket_internal = model_kwargs["bucket_internal"] reduce_recompile = model_kwargs["reduce_recompile"] prompt_len = input_ids.shape[-1] - if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) - if bucket_size > 0: - assert "position_ids" not in model_kwargs, "Untested path" + if not bucket_internal: + if bucket_size >= 0: + inc = iter(incrementor(bucket_size, prompt_len)) + if bucket_size > 0: + assert "position_ids" not in model_kwargs, "Untested path" while True: if lazy_mode: @@ -1392,11 +1404,18 @@ def greedy_search( break if bucket_size > 0: - # it will not have been padded if bucket_size > 0 - params = next(inc) - input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( - params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile - ) + if not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) + else: + # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. + idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") + if idx < (model_kwargs["kv_cache_len"] // bucket_size): + cache_idx = (idx.item() + 1) * bucket_size + model_kwargs["cache_idx"] = cache_idx # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 9222afd793..6978b4577b 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -199,6 +199,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -260,6 +261,12 @@ def pre_attn_forward( key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + if use_cache: if reuse_cache: past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) @@ -414,6 +421,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -437,6 +445,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) self.self_attn.attention_all_reduce(output_pre_attn) output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) @@ -465,6 +474,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -479,6 +489,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + cache_idx=cache_idx, ) return output_attn, attn_weights, present_key_value @@ -527,6 +538,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -637,6 +649,7 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -678,6 +691,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + self.kv_cache_len = max_seq_len def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -703,6 +717,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -725,6 +740,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -810,6 +826,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs From 5b1c1580c53034590afb205d17fc8ae979e8152b Mon Sep 17 00:00:00 2001 From: Xiaotong Chen Date: Tue, 6 Feb 2024 01:10:01 +0800 Subject: [PATCH 2/4] initialize bucket_internal for CI --- optimum/habana/transformers/generation/configuration_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 577b4cbd5a..7b5ed0f7f6 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -44,6 +44,7 @@ def __init__(self, **kwargs): self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None) self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) + self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) From c1fb4f461bf327e5225e44a88117928cae32c0b1 Mon Sep 17 00:00:00 2001 From: Xiaotong Chen Date: Tue, 6 Feb 2024 12:29:01 +0800 Subject: [PATCH 3/4] make bucket_internal more clear --- examples/text-generation/run_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 743ceca91d..6b0b2e4695 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -189,7 +189,7 @@ def setup_parser(parser): parser.add_argument( "--bucket_internal", action="store_true", - help="Split kv sequence into buckets in decode phase. It is useful for long new tokens.", + help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.", ) parser.add_argument( "--dataset_max_samples", From 8b628dfb63f28251efd5bef05992590bfa8745cf Mon Sep 17 00:00:00 2001 From: Xiaotong Chen Date: Thu, 8 Feb 2024 10:33:22 +0800 Subject: [PATCH 4/4] further perf optim while max length is not multiple of bucket size --- optimum/habana/transformers/generation/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 60b306e5d3..4d73193b91 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1380,6 +1380,7 @@ def greedy_search( hb_profer.start() this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs["bucket_size"] + prev_idx = None # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs["bucket_internal"] reduce_recompile = model_kwargs["reduce_recompile"] @@ -1413,10 +1414,14 @@ def greedy_search( ) else: # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. - idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") - if idx < (model_kwargs["kv_cache_len"] // bucket_size): - cache_idx = (idx.item() + 1) * bucket_size - model_kwargs["cache_idx"] = cache_idx + if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") + if idx != prev_idx: + cache_idx = (idx.item() + 1) * bucket_size + model_kwargs["cache_idx"] = cache_idx + prev_idx = idx + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)