diff --git a/examples/contrastive-image-text/README.md b/examples/contrastive-image-text/README.md index b526c6091b..cd4aa92295 100644 --- a/examples/contrastive-image-text/README.md +++ b/examples/contrastive-image-text/README.md @@ -250,5 +250,8 @@ python run_clip.py \ --use_lazy_mode \ --use_hpu_graphs_for_inference \ --gaudi_config_name Habana/clip \ - --bf16 + --bf16 \ + --mediapipe_dataloader ``` + +> `--mediapipe_dataloader` only works on Gaudi2. diff --git a/examples/contrastive-image-text/clip_media_pipe.py b/examples/contrastive-image-text/clip_media_pipe.py old mode 100644 new mode 100755 index 62c2a5651b..574837e38f --- a/examples/contrastive-image-text/clip_media_pipe.py +++ b/examples/contrastive-image-text/clip_media_pipe.py @@ -24,29 +24,37 @@ try: from habana_frameworks.mediapipe import fn - from habana_frameworks.mediapipe.backend.nodes import opnode_tensor_info - from habana_frameworks.mediapipe.backend.operator_specs import schema from habana_frameworks.mediapipe.media_types import dtype, ftype, imgtype, randomCropType, readerOutType from habana_frameworks.mediapipe.mediapipe import MediaPipe - from habana_frameworks.mediapipe.operators.media_nodes import MediaReaderNode from habana_frameworks.mediapipe.operators.reader_nodes.read_image_from_dir import get_max_file + from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import ( + media_ext_reader_op_impl, + media_ext_reader_op_tensor_info, + ) from habana_frameworks.torch.hpu import get_device_name except ImportError: pass +read_image_text_from_dataset_params = { + "label_dtype": dtype.UINT64, + "dataset": None, +} -class read_image_text_from_dataset(MediaReaderNode): + +class read_image_text_from_dataset(media_ext_reader_op_impl): """ - Class defining read image/text from directory node. + Class defining read image/text from clip dataset. 
""" - def __init__(self, name, guid, device, inputs, params, cparams, node_attr): - super().__init__(name, guid, device, inputs, params, cparams, node_attr) + def __init__(self, params): + self.batch_size = 1 + params = params["priv_params"] self.meta_dtype = params["label_dtype"] self.dataset = params["dataset"] self.epoch = 0 - + self.batch_sampler_iter = None + self.iter_loc = 0 self.num_imgs_slice = len(ClipMediaPipe.batch_sampler.sampler) self.num_batches_slice = len(ClipMediaPipe.batch_sampler) @@ -62,13 +70,13 @@ def set_params(self, params): def gen_output_info(self): out_info = [] - o = opnode_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") + o = media_ext_reader_op_tensor_info(dtype.NDT, np.array([self.batch_size], dtype=np.uint32), "") out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) - o = opnode_tensor_info( + o = media_ext_reader_op_tensor_info( self.meta_dtype, np.array([self.dataset.text_max_length, self.batch_size], dtype=np.uint32), "" ) out_info.append(o) @@ -112,27 +120,6 @@ def __next__(self): return img_list, input_id_list, attention_mask_list -read_image_text_from_dataset_params = { - "label_dtype": dtype.UINT64, - "dataset": None, -} -schema.add_operator( - "ClipDataReader", - None, - 0, - 0, - [], - 3, - read_image_text_from_dataset_params, - None, - read_image_text_from_dataset, - dtype.NDT, -) -op_class = fn.operator_add("ClipDataReader") -op_class.__module__ = fn.__name__ -setattr(fn, "ClipDataReader", op_class) - - class ClipMediaPipe(MediaPipe): """ Class defining clip media pipe: @@ -160,8 +147,13 @@ def __init__(self, dataset=None, sampler=None, batch_size=512, drop_last=False, super(ClipMediaPipe, self).__init__( device=self.device, batch_size=batch_size, prefetch_depth=queue_depth, pipe_name=pipe_name ) - - self.input = 
fn.ClipDataReader(label_dtype=dtype.UINT32, dataset=self.dataset) + params = read_image_text_from_dataset_params.copy() + params["dataset"] = self.dataset + self.input = fn.MediaExtReaderOp( + impl=read_image_text_from_dataset, + num_outputs=3, + priv_params=params, + ) def_output_image_size = [self.image_size, self.image_size] res_pp_filter = ftype.BICUBIC self.decode = fn.ImageDecoder( diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 909593427d..783f14171a 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -370,6 +370,7 @@ python3 run_lora_clm.py \ --max_grad_norm 0.3 \ --logging_steps 1 \ --do_train \ + --do_eval \ --use_habana \ --use_lazy_mode \ --throughput_warmup_steps 3 \ @@ -380,6 +381,7 @@ python3 run_lora_clm.py \ --dataset_concatenation \ --max_seq_length 512 \ --low_cpu_mem_usage True \ + --validation_split_percentage 4 \ --adam_epsilon 1e-08 ``` @@ -436,6 +438,7 @@ python ../gaudi_spawn.py \ --max_grad_norm 0.3 \ --logging_steps 1 \ --do_train \ + --do_eval \ --use_habana \ --use_lazy_mode \ --throughput_warmup_steps 3 \ @@ -447,6 +450,7 @@ python ../gaudi_spawn.py \ --max_seq_length 512 \ --ddp_bucket_cap_mb 50 \ --adam_epsilon 1e-08 \ + --validation_split_percentage 4 \ --low_cpu_mem_usage True ``` @@ -550,7 +554,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Falcon-180B: diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index c50f8e6905..1cdc459268 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -243,6 +243,9 @@ class DataTrainingArguments: keep_linebreaks: bool = field( default=True, metadata={"help": 
"Whether to keep line breaks when using TXT files or not."} ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) def __post_init__(self): if self.streaming: @@ -643,7 +646,8 @@ def compute_metrics(eval_preds): elif last_checkpoint is not None: checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() # Saves the tokenizer too for easy upload + if data_args.save_last_ckpt: + trainer.save_model() # Saves the tokenizer too for easy upload metrics = train_result.metrics diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index b480990752..ba3244e57f 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -156,6 +156,23 @@ class ModelArguments: ) }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True.", + ) + }, + ) + use_fused_rope: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use Habana fused-rope for fine-tuning. 
The current support is limited to Llama only.", + ) + }, + ) load_meta_device: bool = field( default=False, metadata={ @@ -537,6 +554,9 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + if not model_args.use_fused_rope: + model.generation_config.use_fused_rope = False if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: tokenizer.pad_token_id = model.generation_config.pad_token_id diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 5732e684a4..332d117e2f 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -236,6 +236,9 @@ python run_generation.py \ `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies no truncation of input prompt in the dataset. Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example `--simulate_dyn_prompt 25,35,45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. 
Thus the suggested configuration to simulate dynamicity after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30` + +While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size). + ### Running with FP8 Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. @@ -293,6 +296,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ ``` `--fp8` is required to enable quantization in fp8. +### Using Habana Flash Attention + +Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. + +The example below uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally, since all sequences in a batch are of the same length, it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation. + +```bash +python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--reuse_cache \ +--trim_logits \ +--attn_softmax_bf16 \ +--max_input_tokens 31744 \ +--max_new_tokens 1024 \ +--batch_size=12 \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask \ +--book_source +``` + +For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). 
## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 14e9712595..048ef827dd 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -186,6 +186,11 @@ def setup_parser(parser): then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \ we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).", ) + parser.add_argument( + "--bucket_internal", + action="store_true", + help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.", + ) parser.add_argument( "--dataset_max_samples", default=-1, @@ -227,6 +232,21 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--flash_attention_recompute", + action="store_true", + help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", + ) + parser.add_argument( + "--flash_attention_causal_mask", + action="store_true", + help="Whether to enable Habana Flash Attention in causal mode on first token generation.", + ) + parser.add_argument( + "--book_source", + action="store_true", + help="Whether to use project Guttenberg books data as input. 
Usefull for testing large sequence lenghts.", + ) parser.add_argument( "--torch_compile", action="store_true", @@ -266,6 +286,45 @@ def main(): # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt + elif args.book_source: + + def download_book(book_id): + import os + + import requests + + url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" + response = requests.get(url) + if response.status_code == 200: + pid = os.getpid() + save_path = f"/tmp/{book_id}_{pid}.txt" + with open(save_path, "wb") as file: + file.write(response.content) + print(f"Book downloaded and saved to: {save_path}") + return save_path + else: + print("Failed to download book! Exiting...") + import sys + + sys.exit() + + def assemble_prompt(prompt_size, book_path): + prompt = "" + counter = 0 + book_lines = open(book_path).readlines() + for line in book_lines: + for word in line.split(): + counter += 1 + prompt += word + " " + if counter == prompt_size: + return [prompt] * args.batch_size + + book_ids = [ + 2701, # Moby Dick; Or, The Whale + 1513, # Romeo and Juliet + 1342, # Pride and Prejudice + ] + input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) else: input_sentences = [ "DeepSpeed is a machine learning framework", @@ -289,6 +348,8 @@ def main(): def generate(size=None, reduce_recompile=False): """Generates sequences from the input sentences and returns them.""" + t0 = time.perf_counter() + print(f"Step4+ starting time is {t0*1000}", flush=True) # Tokenization if args.max_input_tokens > 0: input_tokens = tokenizer.batch_encode_plus( @@ -309,7 +370,7 @@ def generate(size=None, reduce_recompile=False): if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].to(args.device) - outputs = model.generate( + output_tokens = model.generate( **input_tokens, generation_config=generation_config, lazy_mode=use_lazy_mode, @@ -317,7 +378,10 @@ def generate(size=None, 
reduce_recompile=False): profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, ).cpu() - return tokenizer.batch_decode(outputs, skip_special_tokens=True) + outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True) + duration = time.perf_counter() - t0 + print(f"Total E2E time of this iteration is {duration:.3f}s", flush=True) + return outputs from optimum.habana.utils import HabanaProfile diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 9b66de8128..fc7f042223 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -329,6 +329,7 @@ def setup_generation_config(args, model, tokenizer): generation_config.use_cache = args.use_kv_cache generation_config.static_shapes = is_optimized generation_config.bucket_size = args.bucket_size if is_optimized else -1 + generation_config.bucket_internal = args.bucket_internal generation_config.do_sample = args.do_sample generation_config.num_beams = args.num_beams generation_config.bad_words_ids = bad_words_ids @@ -343,6 +344,8 @@ def setup_generation_config(args, model, tokenizer): assert generation_config.bucket_size > 0 generation_config.kv_cache_fp8 = args.kv_cache_fp8 generation_config.use_flash_attention = args.use_flash_attention + generation_config.flash_attention_recompute = args.flash_attention_recompute + generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask return generation_config diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index 07e256372f..c0484e2243 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -73,7 +73,8 @@ class GaudiDynamoBackend(str, BaseEnum): - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read more](https://github.com/intel/intel-extension-for-pytorch). - **TVM** -- Uses Apach TVM for inference optimizations. 
[Read more](https://tvm.apache.org/) - - **AOT_HPU_TRAINING_BACKEND** -- Uses Habana Gaudi. + - **AOT_HPU_TRAINING_BACKEND** -- Uses Habana Gaudi - deprecated - will be removed. + - **HPU_BACKEND** -- Uses Habana Gaudi. """ @@ -92,6 +93,7 @@ class GaudiDynamoBackend(str, BaseEnum): IPEX = "IPEX" TVM = "TVM" AOT_HPU_TRAINING_BACKEND = "AOT_HPU_TRAINING_BACKEND" + HPU_BACKEND = "HPU_BACKEND" @dataclass diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index e0fc139f5d..8cf5070b34 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,7 +3,8 @@ from pathlib import Path import torch -from huggingface_hub import snapshot_download +from huggingface_hub import list_repo_files, snapshot_download +from transformers import modeling_utils from transformers.utils import is_offline_mode @@ -21,7 +22,12 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): print("Offline mode: forcing local_files_only=True") # Only download PyTorch weights by default - allow_patterns = ["*.bin"] + if any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)): + allow_patterns = ["*.bin"] + elif any( + ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token) + ): # Some models like Falcon-180b are in only safetensors format + allow_patterns = ["*.safetensors"] # Download only on first process if local_rank in [-1, 0]: @@ -51,14 +57,24 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): - """ - Gets the list of files for the specified model checkpoint. 
- """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt + # Extensions: .bin | .safetensors | .pt # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] + + if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME) + elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME) + else: + (name, ext) = ("*", ".pt") + + file_list = [ + str(entry) + for entry in Path(cached_repo_dir).rglob("*") + if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext)) + ] + return file_list diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 577b4cbd5a..2e72342263 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): Whether to enable recompute if use Habana flash attention. + flash_attention_causal_mask (`bool`, *optional*): + Whether to enable causal_mask if use Habana flash attention. 
""" def __init__(self, **kwargs): @@ -44,7 +46,10 @@ def __init__(self, **kwargs): self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None) self.reuse_cache = kwargs.get("reuse_cache", None) self.bucket_size = kwargs.get("bucket_size", -1) + self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) + self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bc8fad5118..27e9a7e7e0 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -17,6 +17,7 @@ import copy import inspect import math +import time import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -239,6 +240,8 @@ def _update_model_kwargs_for_generation( if token_idx is not None: token_idx.add_(1) + if "token_idx_cpu" in model_kwargs: + model_kwargs["token_idx_cpu"] += 1 return model_kwargs @@ -542,6 +545,9 @@ def generate( generation_config.ignore_eos = kwargs.get("ignore_eos", lazy_mode) generation_config.validate() model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): + for key in ["token_type_ids"]: + model_kwargs.pop(key, None) self._validate_model_kwargs(model_kwargs.copy()) # 2. 
Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -589,18 +595,28 @@ def generate( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) - is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and ( - self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH - or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + is_greedy_or_beam_and_bucket = ( + not generation_config.bucket_internal + and generation_config.bucket_size > 0 + and ( + self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH + or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH + ) ) model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1 + model_kwargs["bucket_internal"] = generation_config.bucket_internal model_kwargs["reduce_recompile"] = ( generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False ) if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size - if generation_config.reuse_cache: - assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together" + if generation_config.bucket_internal: + assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" + assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal" + if generation_config.reuse_cache and not generation_config.bucket_internal: + assert ( + generation_config.bucket_size <= 0 + ), "please set bucket_internal along with reuse_cache and bucket_size" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -612,6 +628,7 @@ def generate( # token_idx is the current index in the 
generation process, it is incremented each time a new token is generated token_idx = inputs_tensor.shape[-1] model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device) + model_kwargs["token_idx_cpu"] = token_idx inputs_tensor = torch.nn.functional.pad( inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id ) @@ -691,6 +708,7 @@ def generate( model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16 # determine whether limit_hpu_graphs needs to be used + model_kwargs["use_hpu_graphs"] = hpu_graphs model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs # prepare for allocate kv cache @@ -699,6 +717,8 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False + model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] @@ -713,6 +733,8 @@ def generate( token_idx, generation_config.kv_cache_fp8, ) + model_kwargs["kv_cache_len"] = calculated_max_length + if self.config.model_type in ["llama"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) @@ -1368,15 +1390,19 @@ def greedy_search( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] - prompt_len = input_ids.shape[-1] - if bucket_size >= 0: - inc = iter(incrementor(bucket_size, prompt_len)) - if bucket_size > 0: - assert 
"position_ids" not in model_kwargs, "Untested path" + bucket_size = model_kwargs.get("bucket_size", -1) + prev_idx = -1 # avoiding calculate cache_idx when its value is not changing + bucket_internal = model_kwargs["bucket_internal"] + reduce_recompile = model_kwargs.get("reduce_recompile", False) + prompt_len = input_ids.shape[-1] + if not bucket_internal: + if bucket_size >= 0: + inc = iter(incrementor(bucket_size, prompt_len)) + if bucket_size > 0: + assert "position_ids" not in model_kwargs, "Untested path" + greedy_first = True while True: if lazy_mode: self.htcore_generation.mark_step() @@ -1391,7 +1417,7 @@ def greedy_search( if this_peer_finished_flag.item() == 0.0: break - if bucket_size > 0: + if bucket_size > 0 and not bucket_internal: # it will not have been padded if bucket_size > 0 params = next(inc) input_ids, model_kwargs = self.update_model_kwargs_for_bucketing( @@ -1471,6 +1497,18 @@ def greedy_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) + if bucket_size > 0 and bucket_internal: + # Calculate slice idx for kv cache during the decode phase. + # Breaking down the kv cache in the attention block helps to reduce computation time. 
+ if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size: + idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size + if prev_idx != idx: + model_kwargs["cache_idx"] = (idx + 1) * bucket_size + prev_idx = idx + if model_kwargs["use_hpu_graphs"]: + self.clear_cache() + else: + model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"] # if eos_token was found in one sentence, set sentence to finished if not ignore_eos and eos_token_id_tensor is not None: @@ -1487,9 +1525,18 @@ def greedy_search( hb_profer.step() + if greedy_first: + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + print(f"First Token time(greedy):{time.perf_counter()*1000}") + greedy_first = False + if this_peer_finished and not synced_gpus: break + if model_kwargs["use_hpu_graphs"]: + self.clear_cache() hb_profer.stop() if streamer is not None: streamer.end() @@ -1705,6 +1752,7 @@ def sample( hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps) hb_profer.start() this_peer_finished = False # used by synced_gpus only + sample_first = True # auto-regressive generation while True: if lazy_mode: @@ -1805,6 +1853,13 @@ def sample( hb_profer.step() + if sample_first: + import habana_frameworks.torch.hpu as torch_hpu + + torch_hpu.synchronize() + print(f"First Token time(sample):{time.perf_counter()*1000}") + sample_first = False + if this_peer_finished and not synced_gpus: break @@ -2121,8 +2176,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1): hb_profer.start() this_peer_finished = False # used by synced_gpus only - bucket_size = model_kwargs["bucket_size"] - reduce_recompile = model_kwargs["reduce_recompile"] + bucket_size = model_kwargs.get("bucket_size", -1) + reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] if bucket_size >= 0: inc = iter(incrementor(bucket_size, prompt_len)) diff --git 
a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 9222afd793..0d5cae10ab 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -18,15 +18,19 @@ try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True except ImportError: + has_fused_rope = False print("Not using HPU fused kernel for apply_rotary_pos_emb") - FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True except ImportError: + has_fused_rms_norm = False print("Not using HPU fused kernel for RMSNorm") - FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA @@ -60,7 +64,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: + if hidden_states.device.type == "hpu" and has_fused_rms_norm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype @@ -199,6 +203,9 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -209,6 +216,7 @@ def pre_attn_forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ bsz, 
q_len, _ = hidden_states.size() @@ -249,7 +257,9 @@ def pre_attn_forward( else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope + ) if past_key_value is not None or reuse_cache: # reuse k, v, self_attention @@ -260,6 +270,13 @@ def pre_attn_forward( key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + if use_cache: if reuse_cache: past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) @@ -279,10 +296,15 @@ def pre_attn_forward( ) else: # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same lenght + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -414,6 +436,9 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, 
flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -423,6 +448,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ residual = hidden_states output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( @@ -437,6 +463,9 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) self.self_attn.attention_all_reduce(output_pre_attn) output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) @@ -465,6 +494,9 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -479,6 +511,9 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, + cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) return output_attn, attn_weights, present_key_value @@ -527,6 +562,9 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + 
flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -536,6 +574,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -617,6 +656,8 @@ def custom_forward(*inputs): attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + use_fused_rope=use_fused_rope, ) return custom_forward @@ -637,6 +678,9 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -678,6 +722,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + self.kv_cache_len = max_seq_len def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -703,6 +748,9 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions output_hidden_states = ( @@ -725,6 +773,9 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -810,13 +861,15 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs -def apply_customized_rope(q, k, cos, sin, position_ids): - if q.device.type == "hpu" and FusedRoPE: +def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): + if q.device.type == "hpu" and has_fused_rope and use_fused_rope: # TODO: remove `.clone()` when SynapseAI v1.15 is released return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply( k, cos.clone(), sin.clone(), position_ids diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index c04836a815..09d9a25ce4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -874,6 +874,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True + if not self.model.generation_config.use_fused_rope: + inputs["use_fused_rope"] = False # TODO: keep syncs for fast DDP? 
with self.accelerator.accumulate(model): @@ -1421,7 +1425,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - if self.args.pipelining_fwd_bwd: + if self.args.use_lazy_mode and self.args.pipelining_fwd_bwd: self.htcore.mark_step() self.accelerator.backward(loss) @@ -1626,6 +1630,10 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True + if not self.model.generation_config.use_fused_rope: + inputs["use_fused_rope"] = False # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)