From e48398d38776771c197d1bd60f1a973d7763d241 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Wed, 7 Feb 2024 21:56:10 +0530 Subject: [PATCH 01/18] Expose Llama Fused OPs control from run_lora_clm.py (#23) * Expose Llama Fused OPs control from run_lora_clm.py * Update as per review comments --- examples/language-modeling/run_lora_clm.py | 10 ++++++++++ .../generation/configuration_utils.py | 1 + .../habana/transformers/generation/utils.py | 1 + .../models/llama/modeling_llama.py | 18 +++++++++++++++--- optimum/habana/transformers/trainer.py | 4 ++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index b480990752..47da3af150 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -156,6 +156,14 @@ class ModelArguments: ) }, ) + use_fused_rope: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use Habana fused-rope for fine-tuning. The current support is limited to Llama only.", + ) + }, + ) load_meta_device: bool = field( default=False, metadata={ @@ -537,6 +545,8 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + if not model_args.use_fused_rope: + model.generation_config.use_fused_rope = False if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: tokenizer.pad_token_id = model.generation_config.pad_token_id diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 577b4cbd5a..2f8d924226 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -48,3 +48,4 @@ def __init__(self, **kwargs): self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bc8fad5118..b803953785 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -699,6 +699,7 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 9222afd793..65978b9b4d 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -199,6 +199,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -249,7 +250,9 @@ def pre_attn_forward( else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope + ) if past_key_value is not None or reuse_cache: # reuse k, v, self_attention @@ -414,6 +417,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -437,6 +441,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + use_fused_rope=use_fused_rope, ) self.self_attn.attention_all_reduce(output_pre_attn) output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) @@ -465,6 +470,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -479,6 +485,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + use_fused_rope=use_fused_rope, ) return output_attn, attn_weights, present_key_value @@ -527,6 +534,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -617,6 +625,7 @@ def custom_forward(*inputs): attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + use_fused_rope=use_fused_rope, ) return custom_forward @@ -637,6 +646,7 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -703,6 +713,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -725,6 +736,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + use_fused_rope=use_fused_rope, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -815,8 +827,8 @@ def prepare_inputs_for_generation( return model_inputs -def apply_customized_rope(q, k, cos, sin, position_ids): - if q.device.type == "hpu" and FusedRoPE: +def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): + if q.device.type == "hpu" and FusedRoPE and use_fused_rope: # TODO: remove `.clone()` when SynapseAI v1.15 is released return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply( k, cos.clone(), sin.clone(), position_ids diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index c04836a815..2f217a64b9 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -874,6 +874,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if not self.model.generation_config.use_fused_rope: + inputs["use_fused_rope"] = False # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1626,6 +1628,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if not self.model.generation_config.use_fused_rope: + inputs["use_fused_rope"] = False # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) From d5291ae2c3796b1879511365c5e5c94409f2298b Mon Sep 17 00:00:00 2001 From: xt574chen <158136116+xt574chen@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:34:05 +0800 Subject: [PATCH 02/18] enable internal kv bucket in llama (#24) * enable internal kv bucket in llama * initialize bucket_internal for CI * make bucket_internal more clear * further perf optim while max length is not multiple of bucket size --- examples/text-generation/README.md | 3 ++ examples/text-generation/run_generation.py | 5 ++ examples/text-generation/utils.py | 1 + .../generation/configuration_utils.py | 1 + .../habana/transformers/generation/utils.py | 50 ++++++++++++++----- .../models/llama/modeling_llama.py | 17 +++++++ 6 files changed, 64 insertions(+), 13 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 5732e684a4..b57bf49045 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -236,6 +236,9 @@ python run_generation.py \ `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies no truncation of input prompt in the dataset. Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example `--simulate_dyn_prompt 25,35,45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. Thus the suggested configuration to simulate dynamicity after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30` + +While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size) + ### Running with FP8 Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 14e9712595..6b0b2e4695 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -186,6 +186,11 @@ def setup_parser(parser): then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \ we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).", ) + parser.add_argument( + "--bucket_internal", + action="store_true", + help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.", + ) parser.add_argument( "--dataset_max_samples", default=-1, diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 9b66de8128..4bd8f27bb5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -329,6 +329,7 @@ def setup_generation_config(args, model, tokenizer): generation_config.use_cache = args.use_kv_cache generation_config.static_shapes = is_optimized generation_config.bucket_size = args.bucket_size if is_optimized else -1 + generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample generation_config.num_beams = args.num_beams generation_config.bad_words_ids = bad_words_ids diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 2f8d924226..57f12810db 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -44,6 +44,7 @@ def __init__(self, **kwargs): self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None) self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) + self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index b803953785..4d73193b91 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -589,17 +589,25 @@ def generate( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) - is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and ( - self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH - or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + is_greedy_or_beam_and_bucket = ( + not generation_config.bucket_internal + and generation_config.bucket_size > 0 + and ( + self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH + or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + ) ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 + model_kwargs["bucket_internal"] = generation_config.bucket_internal model_kwargs["reduce_recompile"] = ( generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size - if generation_config.reuse_cache: + if generation_config.bucket_internal: + assert generation_config.bucket_size >= 0, "bucket_internal and bucket_size flags set together" + assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal" + if generation_config.reuse_cache and not generation_config.bucket_internal: assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" if generation_config.static_shapes: @@ -714,6 +722,8 @@ def generate( token_idx, generation_config.kv_cache_fp8, ) + model_kwargs["kv_cache_len"] = calculated_max_length + if self.config.model_type in ["llama"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) @@ -1370,13 +1380,16 @@ def greedy_search( hb_profer.start() this_peer_finished = False # used by synced_gpus only bucket_size = model_kwargs["bucket_size"] + prev_idx = None # avoiding calculate cache_idx when its value is not changing + bucket_internal = model_kwargs["bucket_internal"] reduce_recompile = model_kwargs["reduce_recompile"] prompt_len = input_ids.shape[-1] - if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) - if bucket_size > 0: - assert "position_ids" not in model_kwargs, "Untested path" + if not bucket_internal: + if bucket_size >= 0: + inc = iter(incrementor(bucket_size, prompt_len)) + if bucket_size > 0: + assert "position_ids" not in model_kwargs, "Untested path" while True: if lazy_mode: @@ -1393,11 +1406,22 @@ def greedy_search( break if bucket_size > 0: - # it will not have been padded if bucket_size > 0 - params = next(inc) - input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( - params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile - ) + if not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) + else: + # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. + if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") + if idx != prev_idx: + cache_idx = (idx.item() + 1) * bucket_size + model_kwargs["cache_idx"] = cache_idx + prev_idx = idx + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 65978b9b4d..ce55b283be 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -199,6 +199,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -263,6 +264,12 @@ def pre_attn_forward( key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + if use_cache: if reuse_cache: past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) @@ -417,6 +424,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -441,6 +449,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) self.self_attn.attention_all_reduce(output_pre_attn) @@ -470,6 +479,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) @@ -485,6 +495,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) return output_attn, attn_weights, present_key_value @@ -534,6 +545,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -646,6 +658,7 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -688,6 +701,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + self.kv_cache_len = max_seq_len def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -713,6 +727,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -736,6 +751,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) hidden_states = outputs[0] @@ -822,6 +838,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs From fc91b28d2f2e80b78c240bbe949c66eab26c0563 Mon Sep 17 00:00:00 2001 From: Shakked Weinberger <145463809+shakkedw@users.noreply.github.com> Date: Thu, 8 Feb 2024 09:57:23 +0200 Subject: [PATCH 03/18] [SW-173358] add first token prints (#18) * [SW-173358] add first token prints * [SW-173358] rename x to outputs * [SW-173358] make style --- examples/text-generation/run_generation.py | 9 +++++++-- .../habana/transformers/generation/utils.py | 19 ++++++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 6b0b2e4695..d2345c711c 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -294,6 +294,8 @@ def main(): def generate(size=None, reduce_recompile=False): """Generates sequences from the input sentences and returns them.""" + t0 = time.perf_counter() + print(f"Step4+ starting time is {t0*1000}", flush=True) # Tokenization if args.max_input_tokens > 0: input_tokens = tokenizer.batch_encode_plus( @@ -314,7 +316,7 @@ def generate(size=None, reduce_recompile=False): if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].to(args.device) - outputs = model.generate( + output_tokens = model.generate( **input_tokens, generation_config=generation_config, lazy_mode=use_lazy_mode, @@ -322,7 +324,10 @@ def generate(size=None, reduce_recompile=False): profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, ).cpu() - return tokenizer.batch_decode(outputs, skip_special_tokens=True) + outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True) + duration = time.perf_counter() - t0 + print(f"Total E2E time of this iteration is {duration:.3f}s", flush=True) + return outputs from optimum.habana.utils import HabanaProfile diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 4d73193b91..edf9afc4f2 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -17,6 +17,7 @@ import copy import inspect import math +import time import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -1385,12 +1386,13 @@ def greedy_search( reduce_recompile = model_kwargs["reduce_recompile"] prompt_len = input_ids.shape[-1] + if not bucket_internal: if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) if bucket_size > 0: assert "position_ids" not in model_kwargs, "Untested path" - + greedy_first = True while True: if lazy_mode: self.htcore_generation.mark_step() @@ -1512,6 +1514,13 @@ def greedy_search( hb_profer.step() + if greedy_first: + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + print(f"First Token time(greedy):{time.perf_counter()*1000}") + greedy_first = False + if this_peer_finished and not synced_gpus: break @@ -1730,6 +1739,7 @@ def sample( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only + sample_first = True # auto-regressive generation while True: if lazy_mode: @@ -1830,6 +1840,13 @@ def sample( hb_profer.step() + if sample_first: + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + print(f"First Token time(sample):{time.perf_counter()*1000}") + sample_first = False + if this_peer_finished and not synced_gpus: break From bfe362bba72ce447cc1ef2617c316a82febec93b Mon Sep 17 00:00:00 2001 From: Witold Szczurek <152967125+wszczurekhabana@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:02:36 +0100 Subject: [PATCH 04/18] Enable Flash Attention in recompute and causal modes (#21) * Enable Flash Attention in recompute and causal modes * Add flash_attention_causal_mask to generation utils * Propagate Flash Attention causal_mask to finetuning example * Modify README example and provide additional description * Add flash_attention_causal_mask to FT README --- examples/language-modeling/README.md | 3 +- examples/language-modeling/run_lora_clm.py | 10 ++++ examples/text-generation/README.md | 24 +++++++++ examples/text-generation/run_generation.py | 54 +++++++++++++++++++ examples/text-generation/utils.py | 2 + .../generation/configuration_utils.py | 3 ++ .../habana/transformers/generation/utils.py | 1 + .../models/llama/modeling_llama.py | 27 ++++++++-- optimum/habana/transformers/trainer.py | 4 ++ 9 files changed, 123 insertions(+), 5 deletions(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 909593427d..ac4a74ab69 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -550,7 +550,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Falcon-180B: diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 47da3af150..ba3244e57f 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -156,6 +156,15 @@ class ModelArguments: ) }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True.", + ) + }, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -545,6 +554,7 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index b57bf49045..332d117e2f 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -296,6 +296,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ ``` `--fp8` is required to enable quantization in fp8. +### Using Habana Flash Attention + +Habana Flash Attention addresses large sequence lenghts on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. + +Below example uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally since all sequences in a batch are of the same lenght it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation. + +```bash +python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--reuse_cache \ +--trim_logits \ +--attn_softmax_bf16 \ +--max_input_tokens 31744 \ +--max_new_tokens 1024 \ +--batch_size=12 \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask \ +--book_source +``` + +For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index d2345c711c..048ef827dd 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -232,6 +232,21 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--flash_attention_recompute", + action="store_true", + help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", + ) + parser.add_argument( + "--flash_attention_causal_mask", + action="store_true", + help="Whether to enable Habana Flash Attention in causal mode on first token generation.", + ) + parser.add_argument( + "--book_source", + action="store_true", + help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", + ) parser.add_argument( "--torch_compile", action="store_true", @@ -271,6 +286,45 @@ def main(): # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt + elif args.book_source: + + def download_book(book_id): + import os + + import requests + + url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" + response = requests.get(url) + if response.status_code == 200: + pid = os.getpid() + save_path = f"/tmp/{book_id}_{pid}.txt" + with open(save_path, "wb") as file: + file.write(response.content) + print(f"Book downloaded and saved to: {save_path}") + return save_path + else: + print("Failed to download book! Exiting...") + import sys + + sys.exit() + + def assemble_prompt(prompt_size, book_path): + prompt = "" + counter = 0 + book_lines = open(book_path).readlines() + for line in book_lines: + for word in line.split(): + counter += 1 + prompt += word + " " + if counter == prompt_size: + return [prompt] * args.batch_size + + book_ids = [ + 2701, # Moby Dick; Or, The Whale + 1513, # Romeo and Juliet + 1342, # Pride and Prejudice + ] + input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) else: input_sentences = [ "DeepSpeed is a machine learning framework", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 4bd8f27bb5..fc7f042223 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -344,6 +344,8 @@ def setup_generation_config(args, model, tokenizer): assert generation_config.bucket_size > 0 generation_config.kv_cache_fp8 = args.kv_cache_fp8 generation_config.use_flash_attention = args.use_flash_attention + generation_config.flash_attention_recompute = args.flash_attention_recompute + generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask return generation_config diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 57f12810db..2e72342263 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): Whether to enable recompute if use Habana flash attention. + flash_attention_causal_mask (`bool`, *optional*): + Whether to enable causal_mask if use Habana flash attention. """ def __init__(self, **kwargs): @@ -49,4 +51,5 @@ def __init__(self, **kwargs): self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index edf9afc4f2..f1fdf0748c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -708,6 +708,7 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True if not self.config.is_encoder_decoder: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ce55b283be..2622d832bd 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -199,6 +199,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -211,6 +212,7 @@ def pre_attn_forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ bsz, q_len, _ = hidden_states.size() @@ -289,10 +291,15 @@ def pre_attn_forward( ) else: # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same lenght + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -424,6 +431,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -435,6 +443,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ residual = hidden_states output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( @@ -449,6 +458,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -479,6 +489,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -495,6 +506,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -545,6 +557,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -556,6 +569,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -637,6 +651,7 @@ def custom_forward(*inputs): attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, use_fused_rope=use_fused_rope, ) @@ -658,6 +673,7 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -727,6 +743,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -751,6 +768,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -838,6 +856,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), } ) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 2f217a64b9..99514c295b 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -874,6 +874,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True if not self.model.generation_config.use_fused_rope: inputs["use_fused_rope"] = False @@ -1628,6 +1630,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True if not self.model.generation_config.use_fused_rope: inputs["use_fused_rope"] = False From 64013ff100925d09dda19080d8193f791151dfa6 Mon Sep 17 00:00:00 2001 From: Mohit Deopujari Date: Thu, 8 Feb 2024 14:06:53 -0800 Subject: [PATCH 05/18] Fix inference command clip-roberta (#31) --- examples/contrastive-image-text/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/contrastive-image-text/README.md b/examples/contrastive-image-text/README.md index b526c6091b..cd4aa92295 100644 --- a/examples/contrastive-image-text/README.md +++ b/examples/contrastive-image-text/README.md @@ -250,5 +250,8 @@ python run_clip.py \ --use_lazy_mode \ --use_hpu_graphs_for_inference \ --gaudi_config_name Habana/clip \ - --bf16 + --bf16 \ + --mediapipe_dataloader ``` + +> `--mediapipe_dataloader` only works on Gaudi2. From 64fd45a996ff72dad901e373d8871f98df03840a Mon Sep 17 00:00:00 2001 From: Bhargav Date: Fri, 9 Feb 2024 22:00:09 +0530 Subject: [PATCH 06/18] Changing backend name (#32) --- optimum/habana/accelerate/utils/dataclasses.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index 07e256372f..c0484e2243 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -73,7 +73,8 @@ class GaudiDynamoBackend(str, BaseEnum): - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read more](https://github.com/intel/intel-extension-for-pytorch). - **TVM** -- Uses Apach TVM for inference optimizations. [Read more](https://tvm.apache.org/) - - **AOT_HPU_TRAINING_BACKEND** -- Uses Habana Gaudi. + - **AOT_HPU_TRAINING_BACKEND** -- Uses Habana Gaudi - depracated - will be removed. + - **HPU_BACKEND** -- Uses Habana Gaudi. """ @@ -92,6 +93,7 @@ class GaudiDynamoBackend(str, BaseEnum): IPEX = "IPEX" TVM = "TVM" AOT_HPU_TRAINING_BACKEND = "AOT_HPU_TRAINING_BACKEND" + HPU_BACKEND = "HPU_BACKEND" @dataclass From 87443e361ec6bfd41bbc0a5e01dfd2404e949c5e Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Mon, 12 Feb 2024 09:17:39 -0800 Subject: [PATCH 07/18] enable falcon-180b inference (#15) * enable loading falcon-180b ckpt in .safetensors format * Address comments borrowing transformer's way of reading ckpt file * address comments --- optimum/habana/checkpoint_utils.py | 23 +++++++++++++++---- .../habana/transformers/generation/utils.py | 3 +++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index e0fc139f5d..60bf71d58a 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +import transformers from huggingface_hub import snapshot_download from transformers.utils import is_offline_mode @@ -53,13 +54,27 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ Gets the list of files for the specified model checkpoint. + Copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt - # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] - return file_list + index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not safe_index_present: + filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") + + load_index = safe_index_file if safe_index_present else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + file_list = set(index["weight_map"].values()) + return [os.path.join(cached_repo_dir, entry) for entry in file_list] def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index f1fdf0748c..1747868116 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -543,6 +543,9 @@ def generate( generation_config.ignore_eos = kwargs.get("ignore_eos", lazy_mode) generation_config.validate() model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): + for key in ["token_type_ids"]: + model_kwargs.pop(key, None) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() From 19c5e7e4f1fa4055aa534a0dd42a520bcc9303d9 Mon Sep 17 00:00:00 2001 From: Manoj Kumar Date: Tue, 13 Feb 2024 14:29:09 +0530 Subject: [PATCH 08/18] To fix LLAMA-V2-70B-FT-HF (8x) for eager mode (#35) mark_step() should not be called for eager mode Signed-off-by: Manoj Kumar --- optimum/habana/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 99514c295b..09d9a25ce4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1425,7 +1425,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - if self.args.pipelining_fwd_bwd: + if self.args.use_lazy_mode and self.args.pipelining_fwd_bwd: self.htcore.mark_step() self.accelerator.backward(loss) From f4e023905831d3f339f1518492d674756ff7855a Mon Sep 17 00:00:00 2001 From: Taylor Jackle Spriggs <74561858+tjs-intel@users.noreply.github.com> Date: Tue, 13 Feb 2024 13:20:44 -0700 Subject: [PATCH 09/18] Add support for safetensors and sharded checkpoints (#25) Co-authored-by: Sun Choi --- optimum/habana/checkpoint_utils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 60bf71d58a..1a14c52b2c 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -54,10 +54,26 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ Gets the list of files for the specified model checkpoint. - Copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) + # Logic for loading individual weights from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/trainer.py#L2061 + individual_weights = [ + os.path.join(cached_repo_dir, weight_name) + for weight_name in ( + transformers.modeling_utils.SAFE_WEIGHTS_NAME, + transformers.modeling_utils.WEIGHTS_NAME, + ) + ] + checkpoint_files = [] + for weight_file in individual_weights: + if os.path.isfile(weight_file): + checkpoint_files.append(weight_file) + break + if checkpoint_files: + return checkpoint_files + + # Code for loading sharded weights copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) From 453b14a243aad0e4b62e723ce88e95c0d14f93d9 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 14 Feb 2024 12:06:05 -0800 Subject: [PATCH 10/18] Update ckpt loading (#38) * enable loading falcon-180b ckpt in .safetensors format * Address comments borrowing transformer's way of reading ckpt file * address comments * Update ckpt loading PR#15 reads a set of ckpt file names from the index json file. When OH downloads files from the hub instead of loading from a cache dir, get_repo_root() skips downloading the index json file. Thus the PR#15 fails to load file names. This PR scans the path and returns a list of names that matches the pattern * import modeling_utils from transformers --- optimum/habana/checkpoint_utils.py | 59 +++++++++++------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 1a14c52b2c..8cf5070b34 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,8 +3,8 @@ from pathlib import Path import torch -import transformers -from huggingface_hub import snapshot_download +from huggingface_hub import list_repo_files, snapshot_download +from transformers import modeling_utils from transformers.utils import is_offline_mode @@ -22,7 +22,12 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): print("Offline mode: forcing local_files_only=True") # Only download PyTorch weights by default - allow_patterns = ["*.bin"] + if any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)): + allow_patterns = ["*.bin"] + elif any( + ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token) + ): # Some models like Falcon-180b are in only safetensors format + allow_patterns = ["*.safetensors"] # Download only on first process if local_rank in [-1, 0]: @@ -52,45 +57,25 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): - """ - Gets the list of files for the specified model checkpoint. - """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Logic for loading individual weights from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/trainer.py#L2061 - individual_weights = [ - os.path.join(cached_repo_dir, weight_name) - for weight_name in ( - transformers.modeling_utils.SAFE_WEIGHTS_NAME, - transformers.modeling_utils.WEIGHTS_NAME, - ) - ] - checkpoint_files = [] - for weight_file in individual_weights: - if os.path.isfile(weight_file): - checkpoint_files.append(weight_file) - break - if checkpoint_files: - return checkpoint_files - - # Code for loading sharded weights copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 - index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) - safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + # Extensions: .bin | .safetensors | .pt + # Creates a list of paths from all downloaded files in cache dir - index_present = os.path.isfile(index_file) - safe_index_present = os.path.isfile(safe_index_file) - - if not index_present and not safe_index_present: - filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) - raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") - - load_index = safe_index_file if safe_index_present else index_file + if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME) + elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME) + else: + (name, ext) = ("*", ".pt") - with open(load_index, "r", encoding="utf-8") as f: - index = json.load(f) + file_list = [ + str(entry) + for entry in Path(cached_repo_dir).rglob("*") + if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext)) + ] - file_list = set(index["weight_map"].values()) - return [os.path.join(cached_repo_dir, entry) for entry in file_list] + return file_list def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): From e2de09bc9115fa14b52b6a21752dfc4a8ba674ef Mon Sep 17 00:00:00 2001 From: Bhargav Date: Thu, 15 Feb 2024 11:31:45 +0530 Subject: [PATCH 11/18] Fix tests (#669) (#41) Co-authored-by: Sayantan Sarkar --- optimum/habana/transformers/generation/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1747868116..be4ba15915 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1384,13 +1384,13 @@ def greedy_search( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] + bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = None # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs["bucket_internal"] - reduce_recompile = model_kwargs["reduce_recompile"] + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] - + if not bucket_internal: if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) @@ -2167,8 +2167,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] + bucket_size = model_kwargs.get("bucket_size", -1) + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) From 72b88d94f9bf61b681ad87da099e8710026c7798 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 16 Feb 2024 09:51:59 +0530 Subject: [PATCH 12/18] Further fixes for performance with internal bucketing. (#36) * Further fixes for performance with internal bucketing. Also add clear cache() to save memory. make style changes also added. Signed-off-by: Puneesh Khanna * Calculate kv cache sliding idx for the decode phase only. Signed-off-by: Puneesh Khanna * Add hpu graphs check for clear cache. Signed-off-by: Puneesh Khanna --------- Signed-off-by: Puneesh Khanna --- .../habana/transformers/generation/utils.py | 51 +++++++++++-------- .../models/llama/modeling_llama.py | 3 +- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index be4ba15915..27e9a7e7e0 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -240,6 +240,8 @@ def _update_model_kwargs_for_generation( if token_idx is not None: token_idx.add_(1) + if "token_idx_cpu" in model_kwargs: + model_kwargs["token_idx_cpu"] += 1 return model_kwargs @@ -609,10 +611,12 @@ def generate( if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size if generation_config.bucket_internal: - assert generation_config.bucket_size >= 0, "bucket_internal and bucket_size flags set together" + assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal" if generation_config.reuse_cache and not generation_config.bucket_internal: - assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" + assert ( + generation_config.bucket_size <= 0 + ), "please set bucket_internal along with reuse_cache and bucket_size" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -624,6 +628,7 @@ def generate( # token_idx is the current index in the generation process, it is incremented each time a new token is generated token_idx = inputs_tensor.shape[-1] model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device) + model_kwargs["token_idx_cpu"] = token_idx inputs_tensor = torch.nn.functional.pad( inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id ) @@ -703,6 +708,7 @@ def generate( model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16 # determine whether limit_hpu_graphs needs to be used + model_kwargs["use_hpu_graphs"] = hpu_graphs model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs # prepare for allocate kv cache @@ -1384,13 +1390,13 @@ def greedy_search( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only + bucket_size = model_kwargs.get("bucket_size", -1) - prev_idx = None # avoiding calculate cache_idx when its value is not changing + prev_idx = -1 # avoiding calculate cache_idx when its value is not changing bucket_internal = model_kwargs["bucket_internal"] reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] - if not bucket_internal: if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) @@ -1411,23 +1417,12 @@ def greedy_search( if this_peer_finished_flag.item() == 0.0: break - if bucket_size > 0: - if not bucket_internal: - # it will not have been padded if bucket_size > 0 - params = next(inc) - input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( - params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile - ) - else: - # Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time. - if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: - idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor") - if idx != prev_idx: - cache_idx = (idx.item() + 1) * bucket_size - model_kwargs["cache_idx"] = cache_idx - prev_idx = idx - else: - model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] + if bucket_size > 0 and not bucket_internal: + # it will not have been padded if bucket_size > 0 + params = next(inc) + input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( + params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile + ) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1502,6 +1497,18 @@ def greedy_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) + if bucket_size > 0 and bucket_internal: + # Calculate slice idx for kv cache during the decode phase. + # Breaking down the kv cache in the attention block helps to reduce computation time. + if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size + if prev_idx != idx: + model_kwargs["cache_idx"] = (idx + 1) * bucket_size + prev_idx = idx + if model_kwargs["use_hpu_graphs"]: + self.clear_cache() + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] # if eos_token was found in one sentence, set sentence to finished if not ignore_eos and eos_token_id_tensor is not None: @@ -1528,6 +1535,8 @@ def greedy_search( if this_peer_finished and not synced_gpus: break + if model_kwargs["use_hpu_graphs"]: + self.clear_cache() hb_profer.stop() if streamer is not None: streamer.end() diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 2622d832bd..d4de78a788 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -269,7 +269,8 @@ def pre_attn_forward( if cache_idx is not None and q_len == 1: key_states = key_states[:, :, :cache_idx, :] value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] kv_seq_len = key_states.shape[-2] if use_cache: From 99e564325e21c1bfaf81a0dac5016afee2b0c51f Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Fri, 16 Feb 2024 09:26:39 -0800 Subject: [PATCH 13/18] Adding a flag whether to save checkpoint or not. (#37) * Adding a flag whether to save checkpoint or not * Add the flag to a model run script --- examples/language-modeling/run_clm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index c50f8e6905..1cdc459268 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -243,6 +243,9 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) def __post_init__(self): if self.streaming: @@ -643,7 +646,8 @@ def compute_metrics(eval_preds): elif last_checkpoint is not None: checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() # Saves the tokenizer too for easy upload + if data_args.save_last_ckpt: + trainer.save_model() # Saves the tokenizer too for easy upload metrics = train_result.metrics From 8e694b96d1567cb8eb5b1a15de8b23b78e491a4d Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Fri, 16 Feb 2024 09:30:13 -0800 Subject: [PATCH 14/18] Update llama-7b command to include eval (#43) --- examples/language-modeling/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index ac4a74ab69..783f14171a 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -370,6 +370,7 @@ python3 run_lora_clm.py \ --max_grad_norm 0.3 \ --logging_steps 1 \ --do_train \ + --do_eval \ --use_habana \ --use_lazy_mode \ --throughput_warmup_steps 3 \ @@ -380,6 +381,7 @@ python3 run_lora_clm.py \ --dataset_concatenation \ --max_seq_length 512 \ --low_cpu_mem_usage True \ + --validation_split_percentage 4 \ --adam_epsilon 1e-08 ``` @@ -436,6 +438,7 @@ python ../gaudi_spawn.py \ --max_grad_norm 0.3 \ --logging_steps 1 \ --do_train \ + --do_eval \ --use_habana \ --use_lazy_mode \ --throughput_warmup_steps 3 \ @@ -447,6 +450,7 @@ python ../gaudi_spawn.py \ --max_seq_length 512 \ --ddp_bucket_cap_mb 50 \ --adam_epsilon 1e-08 \ + --validation_split_percentage 4 \ --low_cpu_mem_usage True ``` From af2c2c2cc685f3d86a3c058265c144d7332bdaec Mon Sep 17 00:00:00 2001 From: Mohit Deopujari Date: Fri, 16 Feb 2024 13:55:07 -0800 Subject: [PATCH 15/18] [BridgeTower] Fix for NoneType in clip mediapipe (#45) * [SW-174850] Fix for Nonetype in image * Using media external reader API * minor fixes * output info fix * Update media reader function name * make style --- .../contrastive-image-text/clip_media_pipe.py | 60 ++++++++----------- 1 file changed, 26 insertions(+), 34 deletions(-) mode change 100644 => 100755 examples/contrastive-image-text/clip_media_pipe.py diff --git a/examples/contrastive-image-text/clip_media_pipe.py b/examples/contrastive-image-text/clip_media_pipe.py old mode 100644 new mode 100755 index 62c2a5651b..574837e38f --- a/examples/contrastive-image-text/clip_media_pipe.py +++ b/examples/contrastive-image-text/clip_media_pipe.py @@ -24,29 +24,37 @@ try: from habana_frameworks.mediapipe import fn - from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info - from habana_frameworks.mediapipe.backend.operator_specs import schema from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType from habana_frameworks.mediapipe.mediapipe import MediaPipe - from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import ( + media_ext_reader_op_impl, + media_ext_reader_op_tensor_info, + ) from habana_frameworks.torch.hpu import get_device_name except ImportError: pass +read_image_text_from_dataset_params = { + "label_dtype": dtype.UINT64, + "dataset": None, +} -class read_image_text_from_dataset(MediaReaderNode): + +class read_image_text_from_dataset(media_ext_reader_op_impl): """ - Class defining read image/text from directory node. + Class defining read image/text from clip dataset. """ - def __init__(self, name, guid, device, inputs, params, cparams, node_attr): - super().__init__(name, guid, device, inputs, params, cparams, node_attr) + def __init__(self, params): + self.batch_size = 1 + params = params["priv_params"] self.meta_dtype = params["label_dtype"] self.dataset = params["dataset"] self.epoch = 0 - + self.batch_sampler_iter = None + self.iter_loc = 0 self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler) self.num_batches_slice = len(ClipMediaPipe.batch_sampler) @@ -62,13 +70,13 @@ def set_params(self, params): def gen_output_info(self): out_info = [] - o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) @@ -112,27 +120,6 @@ def __next__(self): return img_list, input_id_list, attention_mask_list -read_image_text_from_dataset_params = { - "label_dtype": dtype.UINT64, - "dataset": None, -} -schema.add_operator( - "ClipDataReader", - None, - 0, - 0, - [], - 3, - read_image_text_from_dataset_params, - None, - read_image_text_from_dataset, - dtype.NDT, -) -op_class = fn.operator_add("ClipDataReader") -op_class.__module__ = fn.__name__ -setattr(fn, "ClipDataReader", op_class) - - class ClipMediaPipe(MediaPipe): """ Class defining clip media pipe: @@ -160,8 +147,13 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, super(ClipMediaPipe, self).__init__( device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name ) - - self.input = fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset) + params = read_image_text_from_dataset_params.copy() + params["dataset"] = self.dataset + self.input = fn.MediaExtReaderOp( + impl=read_image_text_from_dataset, + num_outputs=3, + priv_params=params, + ) def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BICUBIC self.decode = fn.ImageDecoder( From c7745153e0a00450587acd782bc9169ecdc49b34 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Mon, 19 Feb 2024 06:11:37 +0200 Subject: [PATCH 16/18] Fix graph breaks in torch compile mode Signed-off-by: Sanju C Sudhakaran --- .../habana/transformers/models/llama/modeling_llama.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d4de78a788..c86ab9a1e1 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -18,12 +18,16 @@ try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True except ImportError: print("Not using HPU fused kernel for apply_rotary_pos_emb") FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True except ImportError: print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None @@ -60,7 +64,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: + if hidden_states.device.type == "hpu" and has_fused_rms_norm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype @@ -865,7 +869,7 @@ def prepare_inputs_for_generation( def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): - if q.device.type == "hpu" and FusedRoPE and use_fused_rope: + if q.device.type == "hpu" and has_fused_rope and use_fused_rope: # TODO: remove `.clone()` when SynapseAI v1.15 is released return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply( k, cos.clone(), sin.clone(), position_ids From 322e3351431a298dfae39ef92b7064f74514ca3a Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Wed, 21 Feb 2024 14:24:09 +0200 Subject: [PATCH 17/18] Fix graph breaks in torch compile mode Signed-off-by: Sanju C Sudhakaran --- optimum/habana/transformers/models/llama/modeling_llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c86ab9a1e1..b9b99d0e7f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -21,8 +21,8 @@ has_fused_rope = True except ImportError: + has_fused_rope = False print("Not using HPU fused kernel for apply_rotary_pos_emb") - FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -30,7 +30,6 @@ has_fused_rms_norm = True except ImportError: print("Not using HPU fused kernel for RMSNorm") - FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA From 5d4dafd67e688590c9fb1f310f2fc08b90097d5b Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Wed, 21 Feb 2024 14:27:04 +0200 Subject: [PATCH 18/18] Fix graph breaks in torch compile mode Signed-off-by: Sanju C Sudhakaran --- optimum/habana/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index b9b99d0e7f..0d5cae10ab 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ has_fused_rms_norm = True except ImportError: + has_fused_rms_norm = False print("Not using HPU fused kernel for RMSNorm") try: