diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 5732e684a4..b57bf49045 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -236,6 +236,9 @@ python run_generation.py \ `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies no truncation of input prompt in the dataset. Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example `--simulate_dyn_prompt 25,35,45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. Thus the suggested configuration to simulate dynamicity after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30` + +While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size). + ### Running with FP8 Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 14e9712595..6b0b2e4695 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -186,6 +186,11 @@ def setup_parser(parser): then we use `shape = prompt_length + max_new_tokens`. 
If a positive number is passed \ we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).", ) + parser.add_argument( + "--bucket_internal", + action="store_true", + help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.", + ) parser.add_argument( "--dataset_max_samples", default=-1, diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 415837fdee..e8c847c2f7 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -336,6 +336,7 @@ def setup_generation_config(args, model, tokenizer): generation_config.use_cache = args.use_kv_cache generation_config.static_shapes = is_optimized generation_config.bucket_size = args.bucket_size if is_optimized else -1 + generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample generation_config.num_beams = args.num_beams generation_config.bad_words_ids = bad_words_ids diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 577b4cbd5a..de537dbe4f 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -27,6 +27,8 @@ class GaudiGenerationConfig(GenerationConfig): If negative (default=-1) pad to max if `static_shapes` is set. Else start with `shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow space by `bucket_size` when needed. Only active if `static_shapes` is used. Can't be used with `reuse_cache`. + bucket_internal (`bool`, *optional*): + Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large. 
kv_cache_fp8 (`bool`, *optional*): Store kv-cache in float8 when kv-cache is used use_flash_attention (`bool`, *optional*): @@ -44,6 +46,7 @@ def __init__(self, **kwargs): self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None) self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) + self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index eaa1f347ab..daef17f829 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -550,11 +550,16 @@ def generate( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) - is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and ( - self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH - or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + is_greedy_or_beam_and_bucket = ( + not generation_config.bucket_internal + and generation_config.bucket_size > 0 + and ( + self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH + or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + ) ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 + model_kwargs["bucket_internal"] = generation_config.bucket_internal model_kwargs["reduce_recompile"] = ( generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False ) @@ -562,7 +567,12 @@ def generate( assert generation_config.bucket_size if generation_config.reuse_cache: assert self.config.model_type in ["llama"], 
"reuse_cache only supported by llama at the moment" - assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" + if not generation_config.bucket_internal: + assert ( + generation_config.bucket_size <= 0 + ), "please set bucket_internal along with reuse_cache and bucket_size" + else: + assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -690,6 +700,8 @@ def generate( token_idx, generation_config.kv_cache_fp8, ) + model_kwargs["kv_cache_len"] = calculated_max_length + if self.config.model_type in ["llama"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) @@ -1329,12 +1341,15 @@ def greedy_search( this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs.get("bucket_size", -1) reduce_recompile = model_kwargs.get("reduce_recompile", False) + prev_idx = None # avoiding calculate cache_idx when its value is not changing + bucket_internal = model_kwargs.get("bucket_internal", None) prompt_len = input_ids.shape[-1] - if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) - if bucket_size > 0: - assert "position_ids" not in model_kwargs, "Untested path" + if not bucket_internal: + if bucket_size >= 0: + inc = iter(incrementor(bucket_size, prompt_len)) + if bucket_size > 0: + assert "position_ids" not in model_kwargs, "Untested path" cur_len = prompt_len token_idx = model_kwargs.get("token_idx", None) if token_idx is not None: @@ -1355,11 +1370,22 @@ def greedy_search( break if bucket_size > 0: - # it will not have been padded if bucket_size > 0 - params = next(inc) - input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( - params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile - ) + 
if not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) + else: + # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. + if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") + if idx != prev_idx: + cache_idx = (idx.item() + 1) * bucket_size + model_kwargs["cache_idx"] = cache_idx + prev_idx = idx + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ee5f152184..12017527ac 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -207,6 +207,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -281,6 +282,12 @@ def pre_attn_forward( key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + if use_cache: if reuse_cache: past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) 
@@ -445,6 +452,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -474,6 +482,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, **kwargs, ) self.self_attn.attention_all_reduce(output_pre_attn) @@ -503,6 +512,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -517,6 +527,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + cache_idx=cache_idx, ) return output_attn, attn_weights, present_key_value @@ -565,6 +576,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -681,6 +693,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -728,6 +741,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + self.kv_cache_len = max_seq_len def reorder_kv_cache(self, beam_idx: 
torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -753,6 +767,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -775,6 +790,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -886,6 +902,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs