From 3c263752dd0cb183506dc4b4a3eaeeb0b752d838 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 31 Jan 2024 18:59:03 +0000 Subject: [PATCH 01/33] enable loading falcon-180b ckpt in .safetensors format --- optimum/habana/checkpoint_utils.py | 7 +++++-- optimum/habana/transformers/generation/utils.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index e0fc139f5d..1403981e5c 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -56,9 +56,12 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt + # Extensions: .bin | .pt | .safetensors # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] + exts = [".bin", ".pt", ".safetensors"] + file_list = [ + str(entry) for entry in Path(cached_repo_dir).rglob("*") if (entry.is_file() and entry.suffix in exts) + ] return file_list diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 7c9b69e212..00e888a64f 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -542,6 +542,9 @@ def generate( generation_config.ignore_eos = kwargs.get("ignore_eos", lazy_mode) generation_config.validate() model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): + for key in ["token_type_ids"]: + model_kwargs.pop(key, None) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() From 5357858f99046713810aadffbed1fb4555c7050e Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Thu, 1 Feb 2024 23:04:35 +0000 Subject: [PATCH 02/33] Address comments borrowing transformer's way of reading ckpt file --- optimum/habana/checkpoint_utils.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 1403981e5c..216e8c27ed 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +import transformers from huggingface_hub import snapshot_download from transformers.utils import is_offline_mode @@ -53,16 +54,27 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ Gets the list of files for the specified model checkpoint. + Copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt | .safetensors - # Creates a list of paths from all downloaded files in cache dir - exts = [".bin", ".pt", ".safetensors"] - file_list = [ - str(entry) for entry in Path(cached_repo_dir).rglob("*") if (entry.is_file() and entry.suffix in exts) - ] - return file_list + index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not safe_index_present: + filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_index = safe_index_file if safe_index_present else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + file_list = set(index["weight_map"].values()) + return [os.path.join(cached_repo_dir, entry) for entry in file_list] def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): From 2c0799c62d2242ed33eeacdf0ba1fc1d2484e6e1 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Mon, 5 Feb 2024 23:41:38 +0000 Subject: [PATCH 03/33] address comments --- optimum/habana/checkpoint_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 216e8c27ed..60bf71d58a 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -65,8 +65,8 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): safe_index_present = os.path.isfile(safe_index_file) if not index_present and not safe_index_present: - filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) - raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") load_index = safe_index_file if safe_index_present else index_file From fbf1bd2a4f1923338420395c9f8acc2b4117baae Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Thu, 8 Feb 2024 18:29:01 +0000 Subject: [PATCH 04/33] reformatted --- optimum/habana/checkpoint_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 60bf71d58a..18f5cbeffe 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -65,7 +65,10 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): safe_index_present = os.path.isfile(safe_index_file) if not index_present and not safe_index_present: - filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + filenames = ( + transformers.modeling_utils.WEIGHTS_INDEX_NAME, + transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME, + ) raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") load_index = safe_index_file if safe_index_present else index_file @@ -95,7 +98,9 @@ def model_on_meta(config): def get_optimized_model_name(config): - from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES + from optimum.habana.transformers.generation import ( + MODELS_OPTIMIZED_WITH_STATIC_SHAPES, + ) for model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: if model_type == config.model_type: From e5e92340eb965cb8da4219baa878d12da47dc009 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 31 Jan 2024 18:59:03 +0000 Subject: [PATCH 05/33] enable loading falcon-180b ckpt in .safetensors format --- optimum/habana/checkpoint_utils.py | 7 +++++-- optimum/habana/transformers/generation/utils.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 0fdd1c6566..dee27a2dbf 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -56,9 +56,12 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt + # Extensions: .bin | .pt | .safetensors # Creates a list of paths from all downloaded files in cache dir - file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()] + exts = [".bin", ".pt", ".safetensors"] + file_list = [ + str(entry) for entry in Path(cached_repo_dir).rglob("*") if (entry.is_file() and entry.suffix in exts) + ] return file_list diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 42da95552e..005c56bdcd 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -500,6 +500,9 @@ def generate( generation_config.ignore_eos = kwargs.get("ignore_eos", lazy_mode) generation_config.validate() model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + if self.config.model_type == "falcon" and "token_type_ids" in kwargs.keys(): + for key in ["token_type_ids"]: + model_kwargs.pop(key, None) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() From 190e29cb69821a256fe955b54cb93b3aeabe79b3 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Thu, 1 Feb 2024 23:04:35 +0000 Subject: [PATCH 06/33] Address comments borrowing transformer's way of reading ckpt file --- optimum/habana/checkpoint_utils.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index dee27a2dbf..1992e03b5a 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +import transformers from huggingface_hub import snapshot_download from transformers.utils import is_offline_mode @@ -53,16 +54,27 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): """ Gets the list of files for the specified model checkpoint. + Copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - # Extensions: .bin | .pt | .safetensors - # Creates a list of paths from all downloaded files in cache dir - exts = [".bin", ".pt", ".safetensors"] - file_list = [ - str(entry) for entry in Path(cached_repo_dir).rglob("*") if (entry.is_file() and entry.suffix in exts) - ] - return file_list + index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not safe_index_present: + filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_index = safe_index_file if safe_index_present else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + file_list = set(index["weight_map"].values()) + return [os.path.join(cached_repo_dir, entry) for entry in file_list] def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): From 34fdc00395e09025cadf6050f5c2666fb22df62a Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Mon, 5 Feb 2024 23:41:38 +0000 Subject: [PATCH 07/33] address comments --- optimum/habana/checkpoint_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 1992e03b5a..10951cc189 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -65,8 +65,8 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): safe_index_present = os.path.isfile(safe_index_file) if not index_present and not safe_index_present: - filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) - raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + filenames = (transformers.modeling_utils.WEIGHTS_INDEX_NAME, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") load_index = safe_index_file if safe_index_present else index_file From f5c3029eac5cbbc19707227183ea2746401adff2 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 14 Feb 2024 01:01:02 +0000 Subject: [PATCH 08/33] Update ckpt loading PR#15 reads a set of ckpt file names from the index json file. When OH downloads files from the hub instead of loading from a cache dir, get_repo_root() skips downloading the index json file. Thus the PR#15 fails to load file names. This PR scans the path and returns a list of names that matches the pattern --- optimum/habana/checkpoint_utils.py | 44 ++++++++++++++---------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 09efa51e9a..9b85621d1f 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -4,7 +4,7 @@ import torch import transformers -from huggingface_hub import snapshot_download +from huggingface_hub import list_repo_files, snapshot_download from transformers.utils import is_offline_mode @@ -22,7 +22,12 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): print("Offline mode: forcing local_files_only=True") # Only download PyTorch weights by default - allow_patterns = ["*.bin"] + if any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)): + allow_patterns = ["*.bin"] + elif any( + ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token) + ): # Some models like Falcon-180b are in only safetensors format + allow_patterns = ["*.safetensors"] # Download only on first process if local_rank in [-1, 0]: @@ -52,32 +57,25 @@ def get_repo_root(model_name_or_path, local_rank=-1, token=None): def get_checkpoint_files(model_name_or_path, local_rank, token=None): - """ - Gets the list of files for the specified model checkpoint. - Copied from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/modeling_utils.py#L414 - """ cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token) - index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.WEIGHTS_INDEX_NAME) - safe_index_file = os.path.join(cached_repo_dir, transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME) - - index_present = os.path.isfile(index_file) - safe_index_present = os.path.isfile(safe_index_file) + # Extensions: .bin | .safetensors | .pt + # Creates a list of paths from all downloaded files in cache dir - if not index_present and not safe_index_present: - filenames = ( - transformers.modeling_utils.WEIGHTS_INDEX_NAME, - transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME, - ) - raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {cached_repo_dir}.") - - load_index = safe_index_file if safe_index_present else index_file + if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(transformers.modeling_utils.WEIGHTS_NAME) + elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")): + (name, ext) = os.path.splitext(transformers.modeling_utils.SAFE_WEIGHTS_NAME) + else: + (name, ext) = ("*", ".pt") - with open(load_index, "r", encoding="utf-8") as f: - index = json.load(f) + file_list = [ + str(entry) + for entry in Path(cached_repo_dir).rglob("*") + if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext)) + ] - file_list = set(index["weight_map"].values()) - return [os.path.join(cached_repo_dir, entry) for entry in file_list] + return file_list def write_checkpoints_json(model_name_or_path, local_rank, f, token=None): From 0064711091d09af51edd2264152e36bd8c52b200 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 14 Feb 2024 17:57:24 +0000 Subject: [PATCH 09/33] import modeling_utils from transformers --- optimum/habana/checkpoint_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/habana/checkpoint_utils.py b/optimum/habana/checkpoint_utils.py index 9b85621d1f..6b1469cf64 100644 --- a/optimum/habana/checkpoint_utils.py +++ b/optimum/habana/checkpoint_utils.py @@ -3,8 +3,8 @@ from pathlib import Path import torch -import transformers from huggingface_hub import list_repo_files, snapshot_download +from transformers import modeling_utils from transformers.utils import is_offline_mode @@ -63,9 +63,9 @@ def get_checkpoint_files(model_name_or_path, local_rank, token=None): # Creates a list of paths from all downloaded files in cache dir if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")): - (name, ext) = os.path.splitext(transformers.modeling_utils.WEIGHTS_NAME) + (name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME) elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")): - (name, ext) = os.path.splitext(transformers.modeling_utils.SAFE_WEIGHTS_NAME) + (name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME) else: (name, ext) = ("*", ".pt") From d90a8e9df1c075565aefde50c76b29a79239a537 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Wed, 6 Mar 2024 20:35:31 +0200 Subject: [PATCH 10/33] enable Falcon FP8 inference --- .../maxabs_measure_falcon.json | 10 + examples/text-generation/utils.py | 2 +- .../habana/transformers/generation/utils.py | 2 +- optimum/habana/transformers/modeling_utils.py | 11 +- .../habana/transformers/models/__init__.py | 5 +- .../transformers/models/falcon/__init__.py | 5 +- .../models/falcon/modeling_falcon.py | 941 ++++++++++++++---- 7 files changed, 762 insertions(+), 214 deletions(-) create mode 100644 examples/text-generation/quantization_config/maxabs_measure_falcon.json diff --git a/examples/text-generation/quantization_config/maxabs_measure_falcon.json b/examples/text-generation/quantization_config/maxabs_measure_falcon.json new file mode 100644 index 0000000000..32e9e2209e --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_measure_falcon.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx", + "measure_exclude": "NONE" +} diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index e8c847c2f7..d4d9dab871 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -237,7 +237,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type == "llama": + if model.config.model_type == "llama" or "falcon": patch_scoped_linear_all_reduce(model) if args.quant_config: diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ed83b65e3d..f9162cf0ed 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -736,7 +736,7 @@ def generate( ) model_kwargs["kv_cache_len"] = calculated_max_length - if self.config.model_type in ["llama"]: + if self.config.model_type in ["llama", "falcon"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 384f033cd6..813d85dfca 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -26,7 +26,10 @@ GaudiBloomMLP, GaudiCodeGenAttention, GaudiCodeGenForCausalLM, + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, GaudiGPT2Attention, GaudiGPT2LMHeadModel, @@ -80,10 +83,7 @@ gaudi_conv1d_forward, gaudi_esm_for_protein_folding_forward, gaudi_esmfolding_trunk_forward, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, - gaudi_generate_speech, gaudi_get_extended_attention_mask, gaudi_gpt2_block_forward, gaudi_gpt2_forward, @@ -287,10 +287,11 @@ def adapt_transformers_to_gaudi(): transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for falcon generation on Gaudi + transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention transformers.models.falcon.modeling_falcon.FalconForCausalLM = GaudiFalconForCausalLM + transformers.models.falcon.modeling_falcon.FalconMLP = GaudiFalconMLP transformers.models.falcon.modeling_falcon.FalconModel = GaudiFalconModel - transformers.models.falcon.modeling_falcon.FalconDecoderLayer.forward = gaudi_falcon_decoder_layer_forward - transformers.models.falcon.modeling_falcon.FalconAttention.forward = gaudi_falcon_attention_forward + transformers.models.falcon.modeling_falcon.FalconDecoderLayer = GaudiFalconDecoderLayer transformers.models.falcon.modeling_falcon.FalconAttention._split_heads = gaudi_falcon_attention_split_heads # Optimization for t5 on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index a369627d2f..a43f9f7b5a 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -43,11 +43,12 @@ gaudi_rot_vec_mul, ) from .falcon import ( + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, ) from .gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward from .gpt_bigcode import ( diff --git a/optimum/habana/transformers/models/falcon/__init__.py b/optimum/habana/transformers/models/falcon/__init__.py index 44ac5451f6..00c73ad110 100644 --- a/optimum/habana/transformers/models/falcon/__init__.py +++ b/optimum/habana/transformers/models/falcon/__init__.py @@ -1,7 +1,8 @@ from .modeling_falcon import ( + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, ) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 98e3555e95..f91b1a704c 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -2,6 +2,7 @@ import math import warnings from typing import Optional, Tuple, Union +import os import torch @@ -27,6 +28,7 @@ import habana_frameworks.torch.core as htcore +from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa @@ -34,13 +36,19 @@ BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) +from transformers.models.falcon.configuration_falcon import FalconConfig from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconDecoderLayer, FalconForCausalLM, + FalconMLP, + FalconLinear, FalconModel, + FalconRotaryEmbedding, apply_rotary_pos_emb, build_alibi_tensor, - dropout_add, ) +from ..modeling_all_models import ScopedLinearAllReduce from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -52,6 +60,25 @@ logger = logging.get_logger(__name__) +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + residual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out.add_(residual) + return out + + def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: # TODO: remove `.clone()` when it is fixed in SynapseAI @@ -111,257 +138,721 @@ def gaudi_falcon_attention_split_heads( return query, key, value -def gaudi_falcon_attention_forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -): - """ - Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - replace F.scaled_dot_product_attention with Habana torch's version - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, invAttnHead=None): + return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention +class ScaledDotProductAttention(nn.Module): + def __init__(self, config: FalconConfig): + super().__init__() + self.head_dim = config.hidden_size // config.num_attention_heads + self.bmm1 = Matmul() + self.bmm2 = Matmul() + self.softmax = Softmax() + + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(self.head_dim) + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + + attn_weight += attn_mask + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight, value) + + +def update(prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + prev_cast = prev.to(orig_cur.dtype) + return prev_cast + else: + return torch.cat((prev, cur), dim=dim) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, query_length, _, _ = query_layer.shape +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 - query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + def update(self, prev, cur, dim, idx, inp_seq_len): + return update(prev, cur, dim, idx, inp_seq_len) - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - if token_idx is not None: - # When token_idx is used, - # past_kv_length = 0 - # static seq len = (input token len + max output token len) - kv_seq_len = layer_past[0].shape[-2] + +class GaudiFalconAttention(FalconAttention): + def __init__(self, config: FalconConfig): + super().__init__(config) + + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim else: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - if token_idx is not None: - past_key.index_copy_(-2, token_idx - 1, key_layer) - past_value.index_copy_(-2, token_idx - 1, value_layer) - key_layer = past_key - value_layer = past_value + qkv_out_dim = 3 * self.hidden_size + + if os.getenv("QUANT_CONFIG", ""): + self.sdpa = ScaledDotProductAttention(config) + + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.max_position_embeddings = config.max_position_embeddings + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + if self.config.new_decoder_architecture: + cache_shape = (batch_size, self.num_heads, max_seq_len, self.head_dim) else: - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - kv_length = key_layer.shape[-2] - if use_cache: - present = (key_layer, value_layer) - else: - present = None + cache_shape = (batch_size, 1, max_seq_len, self.head_dim) + device = self.query_key_value.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + self.rotary_emb._set_cos_sin_cache( + seq_len, self.query_key_value.weight.device, self.query_key_value.weight.dtype + ) - if alibi is None: - if output_attentions: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ): + """ + Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - replace F.scaled_dot_product_attention with Habana torch's version + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). - attn_output = attention_scores @ value_layer + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + + kv_seq_len = key_layer.shape[-2] + if layer_past is not None: + if token_idx is not None: + if reuse_cache: + kv_seq_len = layer_past[0][-2] # layer_past conveys only shapes without kv tensors + else: + kv_seq_len = layer_past[0].shape[-2] + else: + kv_length += layer_past[0].shape[-2] + if alibi is None: + cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) + + if layer_past is not None or reuse_cache: + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + else: + key_layer = update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_layer.shape[-2] + + kv_length = key_layer.shape[-2] + if use_cache: + if reuse_cache: + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + present = (key_layer, value_layer) else: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( + present = None + + if alibi is None: + if output_attentions: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + attn_output = attention_scores @ value_layer + else: + if FusedSDPA: + if os.getenv("QUANT_CONFIG", ""): + attn_output = self.sdpa( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask, 0.0, # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) + # Performance improvement for HPU + if self.training is True and htcore: + htcore.mark_step() + attention_scores = None + + attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, -1) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_scores else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - # Performance improvement for HPU - if self.training is True and htcore: - htcore.mark_step() - attention_scores = None + return attn_output, present + + else: + if self._use_sdpa and not output_attentions and head_mask is None: + if FusedSDPA: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + self.attention_dropout.p if self.training else 0.0, + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, -1) + attn_output = self.dense(attn_output) + else: + matmul_result = query_layer @ key_layer.transpose(-1, -2) - attn_output = self.dense(attn_output) + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - if output_attentions: - return attn_output, present, attention_scores + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + attn_output = self._merge_heads(attn_output) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_probs + else: + return attn_output, present + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ): + """ + Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - replace F.scaled_dot_product_attention with Habana torch's version + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + + kv_seq_len = key_layer.shape[-2] + if layer_past is not None: + if token_idx is not None: + if reuse_cache: + kv_seq_len = layer_past[0][-2] + else: + kv_seq_len = layer_past[0].shape[-2] + else: + kv_seq_len += layer_past[0].shape[-2] + + if alibi is None: + cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) + + if layer_past is not None or reuse_cache: + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + else: + key_layer = update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_layer.shape[-2] + + kv_length = key_layer.shape[-2] + if use_cache: + if reuse_cache: + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + present = (key_layer, value_layer) else: - return attn_output, present + present = None - else: - if self._use_sdpa and not output_attentions and head_mask is None: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( + if alibi is None: + if output_attentions: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). + attn_output = attention_scores @ value_layer + else: + if FusedSDPA: + if os.getenv("QUANT_CONFIG", ""): + attn_output = self.sdpa( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask, - self.attention_dropout.p if self.training else 0.0, - self.is_causal and attention_mask is None and query_length > 1, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) - else: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + # Performance improvement for HPU + if self.training is True and htcore: + htcore.mark_step() + attention_scores = None + + attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, -1) attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_scores + else: + return attn_output, present + else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) + if self._use_sdpa and not output_attentions and head_mask is None: + if FusedSDPA: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + self.attention_dropout.p if self.training else 0.0, + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + attn_output = self.dense(attn_output) + else: + matmul_result = query_layer @ key_layer.transpose(-1, -2) - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) - if head_mask is not None: - attention_probs = attention_probs * head_mask + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + if head_mask is not None: + attention_probs = attention_probs * head_mask - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - attn_output = self.dense(attn_output) + # change view [batch_size, q_length, num_heads * head_dim] + attn_output = self._merge_heads(attn_output) - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - - -def gaudi_falcon_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -): - """ - Copied from FalconDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - add token_idx and position_ids into attention inputs - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + attn_output = self.dense(attn_output) - residual = hidden_states + if output_attentions: + return attn_output, present, attention_probs + else: + return attn_output, present - if self.config.new_decoder_architecture: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) - else: - attention_layernorm_out = self.input_layernorm(hidden_states) - - # Self attention. - attn_outputs = self.self_attention( - attention_layernorm_out, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, + def attention_all_reduce(self, attn_output): + if hasattr(self.dense, "all_reduce"): + self.dense.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.dense, "all_reduce"): + self.dense.post_all_reduce(attn_output) + return attn_output + + +class GaudiFalconMLP(FalconMLP): + def pre_mlp_forward(self, x): + x = self.act(self.dense_h_to_4h(x)) + x = self.dense_4h_to_h(x) + return x + + def mlp_all_reduce(self, x): + if hasattr(self.dense_4h_to_h, "all_reduce"): + self.dense_4h_to_h.all_reduce(x) + + def post_mlp_forward(self, x): + if hasattr(self.dense_4h_to_h, "all_reduce"): + self.dense_4h_to_h.post_all_reduce(x) + return x + + +class GaudiFalconDecoderLayer(FalconDecoderLayer): + def __init__(self, config: FalconConfig): + super().__init__(config) + self.self_attention = GaudiFalconAttention(config) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def update_sincos_cache(self, seq_len): + self.self_attention.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, **kwargs, - ) + ): + """ + Copied from FalconDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - add token_idx and position_ids into attention inputs + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + if not self.config.new_decoder_architecture: + residual = hidden_states + + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + **kwargs, + ) + + attention_output = attn_outputs[0] - attention_output = attn_outputs[0] + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) - if not self.config.new_decoder_architecture: - if self.config.parallel_attn: - mlp_layernorm_out = attention_layernorm_out + outputs = attn_outputs[1:] else: - residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training) - mlp_layernorm_out = self.post_attention_layernorm(residual) + residual = hidden_states + hidden_states, present, attn_scores, mlp_layernorm_out = ( + self.pre_attn( # layernorm+attention before AllReduce + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + **kwargs, + ) + ) - outputs = attn_outputs[1:] + self.self_attention.attention_all_reduce(hidden_states) + hidden_states = self.self_attention.post_attn_forward( + hidden_states + ) - # MLP. - mlp_output = self.mlp(mlp_layernorm_out) + attention_output = hidden_states - if self.config.new_decoder_architecture or self.config.parallel_attn: - mlp_output += attention_output + outputs = (present, attn_scores) - output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + # MLP + if not self.config.new_decoder_architecture: + hidden_states = self.mlp(mlp_layernorm_out) + else: + hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.mlp.post_mlp_forward(hidden_states) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] + if self.config.new_decoder_architecture or self.config.parallel_attn: + hidden_states += attention_output - return outputs # hidden_states, present, attentions + output = dropout_add(hidden_states, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + def pre_attn( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + ): + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + mlp_layernorm_out = None + + # Self attention. + attn_scores = None + if output_attentions: + attn_outputs, present, attn_scores = self.self_attention.pre_attn_forward( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + ) + else: + attn_outputs, present = self.self_attention.pre_attn_forward( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + ) + + return attn_outputs, present, attn_scores, mlp_layernorm_out class GaudiFalconModel(FalconModel): @@ -375,6 +866,14 @@ class GaudiFalconModel(FalconModel): - use old version of _make_causal_mask to workaround toch.triu that is not supported in Synapse """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.h: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def update_sincos_cache(self, seq_len): + for layer in self.h: + layer.update_sincos_cache(seq_len) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -388,6 +887,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -425,8 +926,11 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 - if past_key_values[0] is not None and token_idx is None: - past_key_values_length = past_key_values[0][0].shape[-2] + if past_key_values[0] is not None and token_idx is None: ### non static input + if reuse_cache: + past_key_values_length = past_key_values[0][0][-2] + else: + past_key_values_length = past_key_values[0][0].shape[-2] if self.use_alibi: mask = ( @@ -489,6 +993,7 @@ def forward( attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + else: # 4d mask is passed through the layers attention_mask = _gaudi_prepare_4d_causal_attention_mask( @@ -501,6 +1006,7 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + htcore.mark_step() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -529,6 +1035,8 @@ def forward( output_attentions=output_attentions, alibi=alibi, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -563,8 +1071,16 @@ class GaudiFalconForCausalLM(FalconForCausalLM): - add token_idx and position_ids into model inputs - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + - add new args reuse_cache """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.transformer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + + def update_sincos_cache(self, seq_len): + self.transformer.update_sincos_cache(seq_len) + def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, @@ -574,6 +1090,7 @@ def prepare_inputs_for_generation( token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> dict: + reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -588,6 +1105,10 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. if ( @@ -612,6 +1133,8 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "cache_idx": kwargs.get("cache_idx"), } def forward( @@ -628,6 +1151,9 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + trim_logits: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -649,9 +1175,18 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = transformer_outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1:, :] + lm_logits = self.lm_head(hidden_states) loss = None From ffc8f4cd7041c70e9c38bfc8af16f244d9c32dd5 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Fri, 8 Mar 2024 01:53:15 +0200 Subject: [PATCH 11/33] added example command in readme, code cleanup --- examples/text-generation/README.md | 35 ++- .../maxabs_measure_falcon.json | 10 - .../models/falcon/modeling_falcon.py | 282 ++---------------- 3 files changed, 57 insertions(+), 270 deletions(-) delete mode 100644 examples/text-generation/quantization_config/maxabs_measure_falcon.json diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 83a481970c..ef909ad324 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -241,7 +241,7 @@ While `--bucket_size` works for any model without model file changes, an even mo ### Running with FP8 -Llama2-70b, Llama2-7b and Mixtral-8x7B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. +Llama2-70b, Llama2-7b, Mixtral-8x7B, Falcon-7B, Falcon-40B, and Falcon-180B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. More information on enabling fp8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html @@ -321,6 +321,39 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati --bf16 \ --fp8 ``` + +Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: +```bash +QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ +--use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path tiiuae/falcon-180B \ +--use_hpu_graphs \ +--use_kv_cache \ +--limit_hpu_graphs \ +--max_input_tokens 128 \ +--max_new_tokens 128 \ +--batch_size 1 \ +--bf16 \ +--reuse_cache \ +--trim_logits +``` + +Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: +```bash +QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ +--use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path tiiuae/falcon-180B \ +--use_hpu_graphs \ +--use_kv_cache \ +--limit_hpu_graphs \ +--max_input_tokens 128 \ +--max_new_tokens 2048 \ +--batch_size 110 \ +--bf16 \ +--reuse_cache \ +--trim_logits \ +--fp8 +``` `--fp8` is required to enable quantization in fp8. diff --git a/examples/text-generation/quantization_config/maxabs_measure_falcon.json b/examples/text-generation/quantization_config/maxabs_measure_falcon.json deleted file mode 100644 index 32e9e2209e..0000000000 --- a/examples/text-generation/quantization_config/maxabs_measure_falcon.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "method": "HOOKS", - "mode": "MEASURE", - "observer": "maxabs", - "whitelist": {"types": [], "names": []}, - "blacklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx", - "measure_exclude": "NONE" -} diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index f91b1a704c..a3809a7f8f 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -1,8 +1,8 @@ import contextlib import math +import os import warnings from typing import Optional, Tuple, Union -import os import torch @@ -42,13 +42,10 @@ FalconDecoderLayer, FalconForCausalLM, FalconMLP, - FalconLinear, FalconModel, - FalconRotaryEmbedding, apply_rotary_pos_emb, build_alibi_tensor, ) -from ..modeling_all_models import ScopedLinearAllReduce from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -62,17 +59,8 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ - Dropout add function - - Args: - x (`torch.tensor`, *required*): - input tensor - residual (`torch.tensor`, *required*): - residual tensor - prob (`float`, *required*): - dropout probability - training (`bool`, *required*): - training mode + Copied from transformers.models.falcon.modeling_falcon/dropout_add + https://github.com/huggingface/transformers/blob/b338a6c3b8eda29610d4d472cad8cd87cbfdaaed/src/transformers/models/falcon/modeling_falcon.py#L248 """ out = F.dropout(x, p=prob, training=training) out.add_(residual) @@ -170,6 +158,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa if is_causal: assert attn_mask is None + attn_bias = torch.zeros(L, S, dtype=query.dtype) temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -239,13 +228,6 @@ class GaudiFalconAttention(FalconAttention): def __init__(self, config: FalconConfig): super().__init__(config) - if config.new_decoder_architecture: - qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim - elif config.multi_query: - qkv_out_dim = self.hidden_size + 2 * self.head_dim - else: - qkv_out_dim = 3 * self.hidden_size - if os.getenv("QUANT_CONFIG", ""): self.sdpa = ScaledDotProductAttention(config) @@ -274,197 +256,6 @@ def update_sincos_cache(self, seq_len): seq_len, self.query_key_value.weight.device, self.query_key_value.weight.dtype ) - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - reuse_cache: Optional[bool] = False, - cache_idx: int = None, - **kwargs, - ): - """ - Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - replace F.scaled_dot_product_attention with Habana torch's version - - add new args reuse_cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - if token_idx is not None: - if reuse_cache: - kv_seq_len = layer_past[0][-2] # layer_past conveys only shapes without kv tensors - else: - kv_seq_len = layer_past[0].shape[-2] - else: - kv_length += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None or reuse_cache: - if reuse_cache: - key_layer = self.k_cache(key_layer, -2, token_idx) - value_layer = self.v_cache(value_layer, -2, token_idx) - else: - key_layer = update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - - if cache_idx is not None and query_length == 1: - key_layer = key_layer[:, :, :cache_idx, :] - value_layer = value_layer[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_layer.shape[-2] - - kv_length = key_layer.shape[-2] - if use_cache: - if reuse_cache: - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - present = (key_layer, value_layer) - else: - present = None - - if alibi is None: - if output_attentions: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) - - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - attn_output = attention_scores @ value_layer - else: - if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - attn_output = self.sdpa( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) - else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, - ) - else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - # Performance improvement for HPU - if self.training is True and htcore: - htcore.mark_step() - attention_scores = None - - attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, -1) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_scores - else: - return attn_output, present - - else: - if self._use_sdpa and not output_attentions and head_mask is None: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_layer, - key_layer, - value_layer, - attention_mask, - self.attention_dropout.p if self.training else 0.0, - self.is_causal and attention_mask is None and query_length > 1, - ) - else: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) - - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - def pre_attn_forward( self, hidden_states: torch.Tensor, @@ -485,6 +276,7 @@ def pre_attn_forward( The only differences are: - add new args token_idx and position_ids - replace F.scaled_dot_product_attention with Habana torch's version + - add new arg reuse_cache """ if "padding_mask" in kwargs: warnings.warn( @@ -721,14 +513,10 @@ def forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - if not self.config.new_decoder_architecture: - residual = hidden_states - - attention_layernorm_out = self.input_layernorm(hidden_states) - - # Self attention. - attn_outputs = self.self_attention( - attention_layernorm_out, + residual = hidden_states + hidden_states, present, attn_scores, attention_layernorm_out, mlp_layernorm_out = ( + self.pre_attn( # layernorm + attention before AllReduce + hidden_states, layer_past=layer_past, attention_mask=attention_mask, position_ids=position_ids, @@ -738,11 +526,17 @@ def forward( output_attentions=output_attentions, token_idx=token_idx, reuse_cache=reuse_cache, + cache_idx=cache_idx, **kwargs, ) + ) + + self.self_attention.attention_all_reduce(hidden_states) + hidden_states = self.self_attention.post_attn_forward(hidden_states) - attention_output = attn_outputs[0] + attention_output = hidden_states + if not self.config.new_decoder_architecture: if self.config.parallel_attn: mlp_layernorm_out = attention_layernorm_out else: @@ -751,42 +545,11 @@ def forward( ) mlp_layernorm_out = self.post_attention_layernorm(residual) - outputs = attn_outputs[1:] - else: - residual = hidden_states - hidden_states, present, attn_scores, mlp_layernorm_out = ( - self.pre_attn( # layernorm+attention before AllReduce - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - reuse_cache=reuse_cache, - cache_idx=cache_idx, - **kwargs, - ) - ) + outputs = (present, attn_scores) - self.self_attention.attention_all_reduce(hidden_states) - hidden_states = self.self_attention.post_attn_forward( - hidden_states - ) - - attention_output = hidden_states - - outputs = (present, attn_scores) - - # MLP - if not self.config.new_decoder_architecture: - hidden_states = self.mlp(mlp_layernorm_out) - else: - hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) - self.mlp.mlp_all_reduce(hidden_states) - hidden_states = self.mlp.post_mlp_forward(hidden_states) + hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.mlp.post_mlp_forward(hidden_states) if self.config.new_decoder_architecture or self.config.parallel_attn: hidden_states += attention_output @@ -852,7 +615,7 @@ def pre_attn( cache_idx=cache_idx, ) - return attn_outputs, present, attn_scores, mlp_layernorm_out + return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out class GaudiFalconModel(FalconModel): @@ -864,6 +627,7 @@ class GaudiFalconModel(FalconModel): - set past_key_values_length=0 when token_idx is used (with static input shape) - add new arg tgt_len to _expand_mask because past_key_values_length is no longer valid with token_idx - use old version of _make_causal_mask to workaround toch.triu that is not supported in Synapse + - add new arg reuse_cache """ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -926,7 +690,7 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 - if past_key_values[0] is not None and token_idx is None: ### non static input + if past_key_values[0] is not None and token_idx is None: if reuse_cache: past_key_values_length = past_key_values[0][0][-2] else: From 69073f0787c305717f6200481be88f16dd61867d Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Fri, 8 Mar 2024 21:45:56 +0200 Subject: [PATCH 12/33] resolve issues in finetuning --- examples/text-generation/README.md | 10 ++++------ .../transformers/models/falcon/modeling_falcon.py | 5 ++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index ef909ad324..f7bfaddcc7 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -325,17 +325,15 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ ---use_deepspeed --world_size 8 run_generation.py \ +--use_deepspeed --world_size 8 run_lm_eval.py \ +-o acc_falcon180b_bs1_quant.txt \ --model_name_or_path tiiuae/falcon-180B \ --use_hpu_graphs \ --use_kv_cache \ ---limit_hpu_graphs \ ---max_input_tokens 128 \ ---max_new_tokens 128 \ +--trim_logits \ --batch_size 1 \ --bf16 \ ---reuse_cache \ ---trim_logits +--reuse_cache ``` Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index a3809a7f8f..c9620b1a77 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -63,7 +63,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: https://github.com/huggingface/transformers/blob/b338a6c3b8eda29610d4d472cad8cd87cbfdaaed/src/transformers/models/falcon/modeling_falcon.py#L248 """ out = F.dropout(x, p=prob, training=training) - out.add_(residual) + if training: + out = residual + out + else: + out.add_(residual) return out From 3ef3e9154344302448cfb5b593d61e4bdb4e0d71 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Tue, 12 Mar 2024 01:13:31 +0200 Subject: [PATCH 13/33] enable non reuse cache flow for fp8 --- examples/text-generation/README.md | 1 + .../models/falcon/modeling_falcon.py | 70 ++++++++++--------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index f7bfaddcc7..e14256792a 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -323,6 +323,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati ``` Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: +> Please note that Falcon-180B is a gated model, and users are required to request access to it. Please refer to the instructions provided in the StarCoder example above. ```bash QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ --use_deepspeed --world_size 8 run_lm_eval.py \ diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index c9620b1a77..026fd6bb40 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -178,27 +178,6 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa return self.bmm2(attn_weight, value) -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - cur = cur.to(dtype=prev.dtype) - - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - - if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) - - class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -224,7 +203,23 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) def update(self, prev, cur, dim, idx, inp_seq_len): - return update(prev, cur, dim, idx, inp_seq_len) + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[-2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) class GaudiFalconAttention(FalconAttention): @@ -310,31 +305,39 @@ def pre_attn_forward( cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - if layer_past is not None or reuse_cache: + if use_cache: if reuse_cache: key_layer = self.k_cache(key_layer, -2, token_idx) value_layer = self.v_cache(value_layer, -2, token_idx) + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_layer = update( + if layer_past is None: + past_key = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + past_value = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + layer_past = (past_key, past_value) + key_layer = self.k_cache.update( layer_past[0], key_layer, -2, token_idx, self.inp_seq_len ) # k_layer bs*1, q_len, head_dim - value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = layer_past if cache_idx is not None and query_length == 1: key_layer = key_layer[:, :, :cache_idx, :] value_layer = value_layer[:, :, :cache_idx, :] attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_layer.shape[-2] - - kv_length = key_layer.shape[-2] - if use_cache: - if reuse_cache: - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - present = (key_layer, value_layer) else: present = None + kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] + if alibi is None: if output_attentions: attention_scores = query_layer @ key_layer.transpose(-1, -2) @@ -349,7 +352,6 @@ def pre_attn_forward( attn_output = self.sdpa( query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False ) - else: with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( From ee8a90a008485a1c9de8d07250137bf4df8dd1c8 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Wed, 13 Mar 2024 02:15:18 +0200 Subject: [PATCH 14/33] revert non reuse_cache flow for training due to perf drop --- .../models/falcon/modeling_falcon.py | 105 ++++++++++-------- 1 file changed, 61 insertions(+), 44 deletions(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 026fd6bb40..4acda4f08c 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -178,6 +178,27 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa return self.bmm2(attn_weight, value) +def update(prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + prev_cast = prev.to(orig_cur.dtype) + return prev_cast + else: + return torch.cat((prev, cur), dim=dim) + + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -203,23 +224,7 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) def update(self, prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - cur = cur.to(dtype=prev.dtype) - - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - - if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[-2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) + return update(prev, cur, dim, idx, inp_seq_len) class GaudiFalconAttention(FalconAttention): @@ -306,37 +311,49 @@ def pre_attn_forward( query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) if use_cache: - if reuse_cache: - key_layer = self.k_cache(key_layer, -2, token_idx) - value_layer = self.v_cache(value_layer, -2, token_idx) - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + if self.training: + if layer_past is not None: + key_layer = update(layer_past[0], key_layer, -2, token_idx, self.inp_seq_len) + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = (key_layer, value_layer) + else: + present = None + else: - if layer_past is None: - past_key = torch.zeros( - key_layer.shape, - dtype=self.query_key_value.weight.dtype, - device=self.query_key_value.weight.device, - ) - past_value = torch.zeros( - key_layer.shape, - dtype=self.query_key_value.weight.dtype, - device=self.query_key_value.weight.device, - ) - layer_past = (past_key, past_value) - key_layer = self.k_cache.update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - present = layer_past - - if cache_idx is not None and query_length == 1: - key_layer = key_layer[:, :, :cache_idx, :] - value_layer = value_layer[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if layer_past is None: + past_key = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + past_value = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + layer_past = (past_key, past_value) + key_layer = self.k_cache.update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = layer_past + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] else: present = None - kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] + if self.training and layer_past is None: + kv_length = key_layer.shape[-2] + else: + kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] if alibi is None: if output_attentions: From d90366aeb8c6f06a53c59d9333d651dde8a5ee3c Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Thu, 14 Mar 2024 01:32:31 +0200 Subject: [PATCH 15/33] add falcon180B FP8 test --- tests/test_text_generation_example.py | 29 ++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index c397c7269a..a5638e82b8 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -26,6 +26,9 @@ ("mistralai/Mistral-7B-v0.1", 125.26115369093216), ("mistralai/Mixtral-8x7B-v0.1", 23.78652574031883), ], + "fp8": [ + ("tiiuae/falcon-180B", 47.67900945905787), + ], "deepspeed": [ ("bigscience/bloomz", 36.34664210641816), ("meta-llama/Llama-2-70b-hf", 61.973950428647164), @@ -69,6 +72,7 @@ def _test_text_generation( deepspeed: bool = False, world_size: int = 8, torch_compile: bool = False, + fp8: bool = False, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -106,6 +110,13 @@ def _test_text_generation( if not deepspeed: command.append("--bf16") + if fp8: + command += [ + "--fp8", + "--reuse_cache", + "--trim_logits", + ] + with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") print(f"\n\nCommand to test: {' '.join(command)}\n") @@ -115,7 +126,16 @@ def _test_text_generation( pattern = re.compile(r"([\"\'].+?[\"\'])|\s") command = [x for y in command for x in re.split(pattern, y) if x] - proc = subprocess.run(command, env=env_variables) + if fp8: + os.environ["QUANT_CONFIG"] = os.path.join( + path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json" + ) + subprocess.run(command) + os.environ["QUANT_CONFIG"] = os.path.join( + path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json" + ) + + proc = subprocess.run(command) # Ensure the run finished without any issue # Use try-except to avoid logging the token if used @@ -138,6 +158,13 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str): _test_text_generation(model_name, baseline, token) +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["fp8"]) +def test_text_generation_fp8(model_name: str, baseline: float, token: str): + deepspeed = True if "falcon-180B" in model_name else False + world_size = 8 if "falcon-180B" in model_name else None + _test_text_generation(model_name, baseline, token, deepspeed=deepspeed, world_size=world_size, fp8=True) + + @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["deepspeed"]) def test_text_generation_deepspeed(model_name: str, baseline: float, token: str): world_size = 2 if "opt-66b" in model_name else 8 From 68f63592f2080b2c2c23c64728ce7a9056f3c9ec Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Sat, 16 Mar 2024 07:15:50 +0000 Subject: [PATCH 16/33] fix run_lm_eval.py to save --reuse_cache --- examples/text-generation/run_lm_eval.py | 7 ++++++- optimum/habana/transformers/modeling_utils.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4ae8dcb26c..554490fd57 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,10 +75,15 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type == "llama": + if self.model.config.model_type == "llama" or "falcon": self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, + } + ) + if self.model.config.model_type == "llama": + self.model_inputs.update( + { "attn_softmax_bf16": self.options.attn_softmax_bf16, } ) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f848dac500..66a396cf02 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -85,6 +85,7 @@ gaudi_esm_for_protein_folding_forward, gaudi_esmfolding_trunk_forward, gaudi_falcon_attention_split_heads, + gaudi_generate_speech, gaudi_get_extended_attention_mask, gaudi_gpt2_block_forward, gaudi_gpt2_forward, From 8d045ff4a09b50520bd902d3ccfe842229fa744d Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Tue, 19 Mar 2024 01:29:00 +0000 Subject: [PATCH 17/33] modify comments --- .../models/falcon/modeling_falcon.py | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 4acda4f08c..1044fbbf85 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -228,6 +228,15 @@ def update(self, prev, cur, dim, idx, inp_seq_len): class GaudiFalconAttention(FalconAttention): + """ + Inherits from FalconAttention: https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/falcon/modeling_falcon.py#L267 + The only differences are: + - add new args token_idx and position_ids + - replace F.scaled_dot_product_attention with Habana torch's version for BF16 + - use ScaledDotProductAttention for FP8 quantization + - add new arg reuse_cache + """ + def __init__(self, config: FalconConfig): super().__init__(config) @@ -274,13 +283,6 @@ def pre_attn_forward( cache_idx: int = None, **kwargs, ): - """ - Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - replace F.scaled_dot_product_attention with Habana torch's version - - add new arg reuse_cache - """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" @@ -483,6 +485,10 @@ def post_attn_forward(self, attn_output): class GaudiFalconMLP(FalconMLP): + """ + Inherits from FalconMLP: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + """ + def pre_mlp_forward(self, x): x = self.act(self.dense_h_to_4h(x)) x = self.dense_4h_to_h(x) @@ -499,6 +505,14 @@ def post_mlp_forward(self, x): class GaudiFalconDecoderLayer(FalconDecoderLayer): + """ + Inherits from FalconDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - add token_idx and position_ids into attention inputs + - add new args reuse_cache + """ + def __init__(self, config: FalconConfig): super().__init__(config) self.self_attention = GaudiFalconAttention(config) @@ -524,33 +538,30 @@ def forward( cache_idx: int = None, **kwargs, ): - """ - Copied from FalconDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - add token_idx and position_ids into attention inputs - - add new args reuse_cache - """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states - hidden_states, present, attn_scores, attention_layernorm_out, mlp_layernorm_out = ( - self.pre_attn( # layernorm + attention before AllReduce - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - reuse_cache=reuse_cache, - cache_idx=cache_idx, - **kwargs, - ) + ( + hidden_states, + present, + attn_scores, + attention_layernorm_out, + mlp_layernorm_out, + ) = self.pre_attn( # layernorm + attention before AllReduce + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + **kwargs, ) self.self_attention.attention_all_reduce(hidden_states) @@ -646,9 +657,6 @@ class GaudiFalconModel(FalconModel): The only differences are: - add new args token_idx and position_ids - add token_idx and position_ids into decoder inputs - - set past_key_values_length=0 when token_idx is used (with static input shape) - - add new arg tgt_len to _expand_mask because past_key_values_length is no longer valid with token_idx - - use old version of _make_causal_mask to workaround toch.triu that is not supported in Synapse - add new arg reuse_cache """ From 2ac8038dab7bcd6e501cff7c5a6fccffd2fcc321 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Tue, 19 Mar 2024 16:32:34 -0700 Subject: [PATCH 18/33] add falcon180b FP8 test (#104) --- tests/test_text_generation_example.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index e7c4a98a2f..93912e0f37 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -30,6 +30,9 @@ "fp8": [ ("tiiuae/falcon-180B", 47.67900945905787), ], + "fp8": [ + ("tiiuae/falcon-180B", 47.67900945905787), + ], "deepspeed": [ ("bigscience/bloomz", 36.34664210641816), ("meta-llama/Llama-2-70b-hf", 61.973950428647164), From fe9209472552a3ece716092bfd696c2e2f1af90e Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 20 Mar 2024 21:25:01 +0000 Subject: [PATCH 19/33] fix Falcon view+inplace error --- examples/text-generation/utils.py | 2 +- optimum/habana/transformers/models/falcon/modeling_falcon.py | 5 +++-- tests/test_text_generation_example.py | 3 --- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index d4d9dab871..f5e21ae532 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -237,7 +237,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type == "llama" or "falcon": + if model.config.model_type in ["llama", "falcon"]: patch_scoped_linear_all_reduce(model) if args.quant_config: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 1044fbbf85..7873da6fd6 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -65,9 +65,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: out = F.dropout(x, p=prob, training=training) if training: out = residual + out + return out else: - out.add_(residual) - return out + residual.add_(out) + return residual def apply_customized_rope(q, k, cos, sin, position_ids): diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 93912e0f37..e7c4a98a2f 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -30,9 +30,6 @@ "fp8": [ ("tiiuae/falcon-180B", 47.67900945905787), ], - "fp8": [ - ("tiiuae/falcon-180B", 47.67900945905787), - ], "deepspeed": [ ("bigscience/bloomz", 36.34664210641816), ("meta-llama/Llama-2-70b-hf", 61.973950428647164), From f882a9e682010a801dff04b9e2f519dc88c04a98 Mon Sep 17 00:00:00 2001 From: Pankaj Dixit Date: Tue, 19 Mar 2024 15:05:31 +0200 Subject: [PATCH 20/33] Add Llama7b FSDP test for torch.compile mode --- tests/baselines/llama_7b.json | 42 +++++++++++++++++++++++++++++++- tests/test_examples.py | 45 ++++++++++++++++++++++++++++------- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/tests/baselines/llama_7b.json b/tests/baselines/llama_7b.json index a631e510a4..af519a29ac 100644 --- a/tests/baselines/llama_7b.json +++ b/tests/baselines/llama_7b.json @@ -59,6 +59,46 @@ ] } } + }, + "tatsu-lab/alpaca_fsdpcompile": { + "num_train_epochs": 1, + "eval_batch_size": 1, + "distribution": { + "multi_card": { + "learning_rate": 3e-4, + "train_batch_size": 8, + "perplexity": 2.4502, + "train_runtime": 210.305, + "train_samples_per_second": 85.0801, + "extra_arguments": [ + "--bf16 True", + "--gradient_accumulation_steps 2", + "--evaluation_strategy no", + "--save_strategy no", + "--warmup_ratio 0.03", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--lora_rank 8", + "--lora_alpha 16", + "--lora_dropout 0.05", + "--lora_target_modules q_proj v_proj", + "--dataset_concatenation", + "--max_seq_length 512", + "--low_cpu_mem_usage True", + "--adam_epsilon 1e-08", + "--ddp_bucket_cap_mb 50", + "--validation_split_percentage 10", + "--attn_softmax_bf16", + "--pipelining_fwd_bwd False", + "--fsdp auto_wrap", + "--torch_compile_backend hpu_backend", + "--torch_compile", + "--fsdp_config examples/language-modeling/fsdp_config.json" + ] + } + } } } -} \ No newline at end of file +} + diff --git a/tests/test_examples.py b/tests/test_examples.py index a8cea6163a..ab1b434e2d 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -164,7 +164,7 @@ class ExampleTestMeta(type): """ @staticmethod - def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: str): + def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: str, fsdp: bool): models_with_specific_rules = [ "albert-xxlarge-v1", "gpt2-xl", @@ -198,14 +198,14 @@ def to_test(model_name: str, multi_card: bool, deepspeed: bool, example_name: st return True elif "bridgetower" in model_name and os.environ.get("GAUDI2_CI", "0") == "1": return True - elif "falcon" in model_name and os.environ.get("GAUDI2_CI", "0") == "1": + elif "falcon" in model_name and os.environ.get("GAUDI2_CI", "0") == "1" and not fsdp: return True elif "bloom" in model_name and deepspeed and os.environ.get("GAUDI2_CI", "0") == "0": return True return False - def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepspeed=False): + def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepspeed=False, fsdp=False): distribution = "single_card" if multi_card: distribution = "multi_card" @@ -216,7 +216,7 @@ def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepsp models_to_test = _SCRIPT_TO_MODEL_MAPPING.get(example_name) if models_to_test is None: if example_name in ["run_esmfold", "run_lora_clm"]: - attrs[f"test_{example_name}_{distribution}"] = cls._create_test(None, None, None, None) + attrs[f"test_{example_name}_{distribution}"] = cls._create_test(None, None, None, None, None) attrs["EXAMPLE_NAME"] = example_name return super().__new__(cls, name, bases, attrs) else: @@ -225,16 +225,21 @@ def __new__(cls, name, bases, attrs, example_name=None, multi_card=False, deepsp ) for model_name, gaudi_config_name in models_to_test: - if cls.to_test(model_name, multi_card, deepspeed, example_name): + if cls.to_test(model_name, multi_card, deepspeed, example_name, fsdp): attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test( - model_name, gaudi_config_name, multi_card, deepspeed + model_name, gaudi_config_name, multi_card, deepspeed, fsdp ) attrs["EXAMPLE_NAME"] = example_name return super().__new__(cls, name, bases, attrs) @classmethod def _create_test( - cls, model_name: str, gaudi_config_name: str, multi_card: bool = False, deepspeed: bool = False + cls, + model_name: str, + gaudi_config_name: str, + multi_card: bool = False, + deepspeed: bool = False, + fsdp: bool = False, ) -> Callable[[], None]: """ Create a test function that runs an example for a specific (model_name, gaudi_config_name) pair. @@ -310,11 +315,15 @@ def test(self): env_variables["DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED"] = "1" env_variables["PT_HPU_MAX_COMPOUND_OP_SYNC"] = "1" env_variables["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1" + elif fsdp: + env_variables["LOWER_LIST"] = str(example_script.parent / "ops_bf16.txt") + env_variables["PT_HPU_LAZY_MODE"] = "0" with TemporaryDirectory() as tmp_dir: cmd_line = self._create_command_line( multi_card, deepspeed, + fsdp, example_script, model_name, gaudi_config_name, @@ -364,6 +373,7 @@ class ExampleTesterBase(TestCase): EXAMPLE_NAME = None TASK_NAME = None DATASET_PARAMETER_NAME = "dataset_name" + DATASET_NAME = None REGRESSION_METRICS = { "eval_f1": (TestCase.assertGreaterEqual, ACCURACY_PERF_FACTOR), "eval_accuracy": (TestCase.assertGreaterEqual, ACCURACY_PERF_FACTOR), @@ -379,6 +389,7 @@ def _create_command_line( self, multi_card: bool, deepspeed: bool, + fsdp: bool, script: Path, model_name: str, gaudi_config_name: str, @@ -390,7 +401,8 @@ def _create_command_line( task: Optional[str] = None, extra_command_line_arguments: Optional[List[str]] = None, ) -> List[str]: - task_option = f"--{self.DATASET_PARAMETER_NAME} {task}" if task else " " + dataset_name = self.DATASET_NAME if self.DATASET_NAME else task + task_option = f"--{self.DATASET_PARAMETER_NAME} {dataset_name}" if task else " " cmd_line = ["python3"] if multi_card: @@ -418,11 +430,15 @@ def _create_command_line( f"--per_device_eval_batch_size {eval_batch_size}", f" --num_train_epochs {num_epochs}", "--use_habana", - "--use_lazy_mode", "--throughput_warmup_steps 3", "--save_strategy no", ] + if "compile" in task: + cmd_line += ["--use_lazy_mode False"] + else: + cmd_line += ["--use_lazy_mode"] + if "bloom" not in model_name: cmd_line.append("--do_eval") @@ -604,3 +620,14 @@ class MultiCardSeq2SeqSpeechRecognitionExampleTester( ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_speech_recognition_seq2seq", multi_card=True ): TASK_NAME = "mozilla-foundation/common_voice_11_0" + + +class MultiCardCausalLanguageModelingLORAFSDPCompileExampleTester( + ExampleTesterBase, + metaclass=ExampleTestMeta, + example_name="run_lora_clm", + multi_card=True, + fsdp=True, +): + TASK_NAME = "tatsu-lab/alpaca_fsdpcompile" + DATASET_NAME = "tatsu-lab/alpaca" From 12064e3f3c4853b47285152657ad05d70fbdddbb Mon Sep 17 00:00:00 2001 From: Vivek Date: Wed, 28 Feb 2024 13:35:11 +0200 Subject: [PATCH 21/33] Clean-up BERT-BASE FSDP test --- tests/test_fsdp_examples.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py index 29198ae7bd..248647cb53 100644 --- a/tests/test_fsdp_examples.py +++ b/tests/test_fsdp_examples.py @@ -54,8 +54,6 @@ def _test_fsdp( world_size: int = 8, ): os.environ["PT_HPU_LAZY_MODE"] = "0" - os.environ["PT_HPU_EAGER_4_STAGE_PIPELINE_ENABLE"] = "0" # To be removed later - os.environ["PT_HPU_EAGER_PIPELINE_ENABLE"] = "0" # To be removed later path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" # Install question-answering example requirements From 7a409fa58afb16b24fea2002173205420f35ab22 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Thu, 29 Feb 2024 20:22:33 +0530 Subject: [PATCH 22/33] enable hpu_graph support for wav2vec2-asr (#59) --- examples/speech-recognition/README.md | 11 +- optimum/habana/transformers/modeling_utils.py | 2 + .../habana/transformers/models/__init__.py | 1 + .../transformers/models/wav2vec2/__init__.py | 1 + .../models/wav2vec2/modeling_wav2vec2.py | 161 +++++++++++++----- tests/baselines/wav2vec2_large_lv60.json | 8 +- 6 files changed, 137 insertions(+), 47 deletions(-) diff --git a/examples/speech-recognition/README.md b/examples/speech-recognition/README.md index 30d4d54598..0d037481c8 100644 --- a/examples/speech-recognition/README.md +++ b/examples/speech-recognition/README.md @@ -78,7 +78,9 @@ python run_speech_recognition_ctc.py \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ --throughput_warmup_steps="3" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_grpahs_for_inference ``` On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**. @@ -117,7 +119,9 @@ python ../gaudi_spawn.py \ --use_lazy_mode \ --gaudi_config_name Habana/wav2vec2 \ --throughput_warmup_steps 3 \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference ``` On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**. @@ -196,7 +200,8 @@ python run_speech_recognition_ctc.py \ --use_habana \ --use_lazy_mode \ --gaudi_config_name="Habana/wav2vec2" \ - --bf16 + --bf16 \ + --use_hpu_graphs_for_inference ``` ## Sequence to Sequence diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f7f88cc690..3c1d186154 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -135,6 +135,7 @@ gaudi_vit_self_attention_forward, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, ) @@ -162,6 +163,7 @@ def adapt_transformers_to_gaudi(): ) transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward = gaudi_wav2vec2_forward transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward + transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer.forward = gaudi_wav2vec2_tdnnlayer_forward # Generation is modified to run faster in lazy mode diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 481ddae49d..5631859530 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -146,5 +146,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py index 3a5bae22b8..df43104ce5 100644 --- a/optimum/habana/transformers/models/wav2vec2/__init__.py +++ b/optimum/habana/transformers/models/wav2vec2/__init__.py @@ -4,5 +4,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, + gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index 983c5b5375..328461d3f2 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,13 +17,18 @@ from typing import Optional, Tuple, Union import torch +from habana_frameworks.torch.hpex.kernels import CTCLoss from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, + CausalLMOutput, Wav2Vec2BaseModelOutput, ) +ctc_loss_fwd = CTCLoss.apply + + def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -33,7 +38,8 @@ def _gaudi_wav2vec2_compute_mask_indices( ) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L135 - The only difference is that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers). + The only differences are (1) that the processing is performed with PyTorch on HPUs (Numpy is used in Transformers), (2) epsilon is generated on HPU instead of CPU, (3) check + to ensure indices are not larger than sequence length is re-written to avoid host sync. """ batch_size, sequence_length = shape @@ -122,8 +128,9 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) + inverse_mask = torch.bitwise_not(mask) + spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) @@ -172,6 +179,63 @@ def _gaudi_wav2vec2_sample_negative_indices( return sampled_negative_indices +def gaudi_wav2vec2_forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + """ + Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 + The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _gaudi_wav2vec2_mask_hidden_states( self, hidden_states: torch.FloatTensor, @@ -300,61 +364,74 @@ def gaudi_wav2vec2_encoder_forward( ) -def gaudi_wav2vec2_forward( +_HIDDEN_STATES_START_POSITION = 2 + + +def gaudi_wav2vec2forctc_forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, -) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + labels: Optional[torch.Tensor] = None, +) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. """ - Copied from Transformers: https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1282 - The only difference is that a clone of `hidden_states` is given to _mask_hidden_states to avoid an error. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 + only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid + changing flattened_targets tensor shapes across training iterations. + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - extract_features = self.feature_extractor(input_values) - extract_features = extract_features.transpose(1, 2) - - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states( - hidden_states.clone(), mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) - - encoder_outputs = self.encoder( - hidden_states, + outputs = self.wav2vec2( + input_values, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones_like(input_values, dtype=torch.long, device="hpu") + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels + # ctc_loss doesn't support fp16 + log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + with torch.backends.cudnn.flags(enabled=False): + loss = ctc_loss_fwd( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) if not return_dict: - return (hidden_states, extract_features) + encoder_outputs[1:] - - return Wav2Vec2BaseModelOutput( - last_hidden_state=hidden_states, - extract_features=extract_features, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index b1071302fa..86fa3b92b5 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -21,7 +21,9 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--use_hpu_graphs_for_training", + "--use_hpu_graphs_for_inference" ] } } @@ -49,7 +51,9 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'" + "--chars_to_ignore ',?.!-;:\"“%‘”'", + "--use_hpu_graphs_for_training", + "--use_hpu_graphs_for_inference" ] } } From f10da0838ec100a42c740b61f34de5e1224cf353 Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Thu, 7 Mar 2024 09:29:24 +0530 Subject: [PATCH 23/33] Run custom ctc_loss only for Gaudi2 (#95) --- .../models/wav2vec2/modeling_wav2vec2.py | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index 328461d3f2..72bdd09ec1 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -18,6 +18,7 @@ import torch from habana_frameworks.torch.hpex.kernels import CTCLoss +from habana_frameworks.torch.hpu import get_device_name from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, @@ -128,9 +129,13 @@ def compute_num_masked_span(input_length): spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length - mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) - inverse_mask = torch.bitwise_not(mask) - spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask + if get_device_name() == "GAUDI": + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + else: + mask = (spec_aug_mask_idxs > sequence_length - 1) * (spec_aug_mask_idxs.max() > sequence_length - 1) + inverse_mask = torch.bitwise_not(mask) + spec_aug_mask_idxs = spec_aug_mask_idxs * inverse_mask + (sequence_length - 1) * mask # scatter indices to mask spec_aug_mask.scatter_(-1, spec_aug_mask_idxs, 1) @@ -414,19 +419,32 @@ def gaudi_wav2vec2forctc_forward( # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) - flattened_targets = labels # ctc_loss doesn't support fp16 log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) - with torch.backends.cudnn.flags(enabled=False): - loss = ctc_loss_fwd( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - self.config.pad_token_id, - self.config.ctc_loss_reduction, - self.config.ctc_zero_infinity, - ) + if get_device_name() == "GAUDI": + flattened_targets = labels.masked_select(labels_mask) + with torch.backends.cudnn.flags(enabled=False): + loss = torch.nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + else: + flattened_targets = labels + with torch.backends.cudnn.flags(enabled=False): + loss = ctc_loss_fwd( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + self.config.pad_token_id, + self.config.ctc_loss_reduction, + self.config.ctc_zero_infinity, + ) if not return_dict: output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] From 6739c0994456ac06ec012fa6434ad95adb3ccc91 Mon Sep 17 00:00:00 2001 From: Vivek Date: Sat, 9 Mar 2024 09:57:35 +0200 Subject: [PATCH 24/33] Update test baseline --- tests/baselines/wav2vec2_large_lv60.json | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/baselines/wav2vec2_large_lv60.json b/tests/baselines/wav2vec2_large_lv60.json index 86fa3b92b5..6792b855ee 100644 --- a/tests/baselines/wav2vec2_large_lv60.json +++ b/tests/baselines/wav2vec2_large_lv60.json @@ -21,9 +21,7 @@ "--layerdrop 0.0", "--freeze_feature_encoder", "--dataloader_num_workers 8", - "--chars_to_ignore ',?.!-;:\"“%‘”'", - "--use_hpu_graphs_for_training", - "--use_hpu_graphs_for_inference" + "--chars_to_ignore ',?.!-;:\"“%‘”'" ] } } @@ -35,12 +33,12 @@ "eval_batch_size": 8, "distribution": { "multi_card": { - "learning_rate": 3e-4, + "learning_rate": 4e-4, "train_batch_size": 8, - "eval_wer": 0.0531535105117017, - "train_runtime": 356.4723, - "train_samples_per_second": 183.245, - "eval_samples_per_second": 158.985, + "eval_wer": 0.06120587068623562, + "train_runtime": 308.8036, + "train_samples_per_second": 225.572, + "eval_samples_per_second": 196.665, "extra_arguments": [ "--dataset_config_name clean", "--train_split_name train.100", From 8a59700fdff783f229cb8470334e2e05ef802921 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 23 Mar 2024 00:29:53 +0000 Subject: [PATCH 25/33] Enable Llama2 70B to run with hqt on single card (#50) Add disk_offload flag that controls device_map=auto. Setting this flag enbales weights offload to disk when cpu memory runs OOM. Add const serialization path flag that gets a path for where to serialize const sections, so if there is no space on device to save all const sections they will be offloaded to disk. --- examples/text-generation/run_generation.py | 1 - examples/text-generation/run_lm_eval.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 2c89c63256..aff06f288b 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -573,7 +573,6 @@ def generate_dataset(batch): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil - shutil.rmtree(args.const_serialization_path) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index b94fe59454..2c324d58de 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -183,7 +183,6 @@ def main(): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil - shutil.rmtree(args.const_serialization_path) From bc0993f77bd020e9680262e7f3db6618ed179351 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 23 Mar 2024 00:47:16 +0000 Subject: [PATCH 26/33] Cherry pick llama fp8 - enable non reuse cache flow for fp8 (#64) --- examples/text-generation/README.md | 1 - examples/text-generation/run_generation.py | 10 +-- examples/text-generation/run_lm_eval.py | 8 +- examples/text-generation/utils.py | 26 +++--- .../generation/configuration_utils.py | 3 - .../habana/transformers/generation/utils.py | 5 +- .../models/llama/modeling_llama.py | 79 +++++++++---------- 7 files changed, 56 insertions(+), 76 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 0cef3cb569..0f9a2c7b16 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -108,7 +108,6 @@ Here are a few settings you may be interested in: - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it - `--fp8` Enable Quantization to fp8 -- `--kv_cache_fp8` Deprecated - Store kv-cache in float8 when kv-cache is used. should not be used with HQT(The Quantization Toolkit) For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: ```bash diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index aff06f288b..1f503ed5e1 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -221,11 +221,6 @@ def setup_parser(parser): help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)", ) - parser.add_argument( - "--kv_cache_fp8", - action="store_true", - help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var", - ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") parser.add_argument( "--use_flash_attention", @@ -259,10 +254,6 @@ def setup_parser(parser): args.limit_hpu_graphs = False args.quant_config = os.getenv("QUANT_CONFIG", "") - if args.quant_config and args.kv_cache_fp8: - # can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization - # with habana quantization toolkit - raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument") return args @@ -573,6 +564,7 @@ def generate_dataset(batch): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil + shutil.rmtree(args.const_serialization_path) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 2c324d58de..8d61118890 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -136,12 +136,7 @@ def _model_call(self, inps): if self.options.static_shapes: bucket_length = self.find_bucket(seq_length) if self.options.use_cache and self.options.reuse_cache: - self.model.allocate_kv_cache( - bs, - bucket_length + 1, - bucket_length, - False, - ) + self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length) padding_length = bucket_length - seq_length inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu() @@ -183,6 +178,7 @@ def main(): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil + shutil.rmtree(args.const_serialization_path) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index ca42a7a021..ca5de7ce58 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -96,15 +96,15 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) -def setup_quantization(args, model): - import habana_frameworks.torch.core as htcore - from habana_frameworks.torch.hpu import hpu - - print("Initializing inference with quantization") - if not args.quant_config: - hpu.enable_quantization() - htcore.hpu_initialize(model) - return model +def setup_const_serialization(const_serialization_path): + import uuid + + const_serialization_path = os.path.join(const_serialization_path + uuid.uuid4().hex) + os.makedirs(const_serialization_path) + from habana_frameworks.torch.hpu import enable_const_section_serialization + + print("Serializing const params to {}".format(const_serialization_path)) + enable_const_section_serialization(const_serialization_path, False, True) def setup_env(args): @@ -346,7 +346,6 @@ def setup_generation_config(args, model, tokenizer): generation_config.reduce_recompile = args.reduce_recompile if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 - generation_config.kv_cache_fp8 = args.kv_cache_fp8 generation_config.use_flash_attention = args.use_flash_attention return generation_config @@ -392,7 +391,12 @@ def initialize_model(args, logger): print("Serializing const params to {}".format(args.const_serialization_path)) enable_const_section_serialization(args.const_serialization_path, True) if args.fp8: - model = setup_quantization(args, model) + import habana_frameworks.torch.core as htcore + + print("Initializing inference mode") + const_marking = os.getenv("ENABLE_CONST_MARKING", "True") + if const_marking == "True": + htcore.hpu_initialize(model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index e75e48a7c7..93df1335db 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -29,8 +29,6 @@ class GaudiGenerationConfig(GenerationConfig): Only active if `static_shapes` is used. Can't be used with `reuse_cache`. bucket_internal (`bool`, *optional*): Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large. - kv_cache_fp8 (`bool`, *optional*): - Store kv-cache in float8 when kv-cache is used use_flash_attention (`bool`, *optional*): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): @@ -48,7 +46,6 @@ def __init__(self, **kwargs): self.bucket_size = kwargs.get("bucket_size", -1) self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) - self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 9aa0f73194..cbc38d737b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -733,10 +733,7 @@ def generate( bs, _ = input_ids.shape if not is_greedy_or_beam_and_bucket: unwrap_deepspeed_model(self).allocate_kv_cache( - bs * generation_config.num_beams, - calculated_max_length, - token_idx, - generation_config.kv_cache_fp8, + bs * generation_config.num_beams, calculated_max_length, token_idx ) model_kwargs["kv_cache_len"] = calculated_max_length diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4f526506b0..15c74dc15b 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -46,25 +46,6 @@ FusedSDPA = None -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) - - def gaudi_llama_rmsnorm_forward(self, hidden_states): """ Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -171,11 +152,9 @@ def __init__(self): self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -183,13 +162,29 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiLlamaRotaryEmbedding(torch.nn.Module): @@ -273,12 +268,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.k_proj.weight.device dtype = self.config.torch_dtype - self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) - self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -373,14 +368,21 @@ def pre_attn_forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: + if use_cache: # reuse k, v, self_attention if reuse_cache: key_states = self.k_cache(key_states, 2, token_idx) value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if cache_idx is not None and q_len == 1: key_states = key_states[:, :, :cache_idx, :] @@ -388,12 +390,6 @@ def pre_attn_forward( if attention_mask is not None: attention_mask = attention_mask[:, :, :, :cache_idx] kv_seq_len = key_states.shape[-2] - - if use_cache: - if reuse_cache: - past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) else: past_key_value = None @@ -475,8 +471,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -631,9 +627,9 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -822,9 +818,8 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) - self.kv_cache_len = max_seq_len + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) From 172406dc06a91f9148bb1037a5732b5a518caa2d Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 23 Mar 2024 00:59:38 +0000 Subject: [PATCH 27/33] Fix merge for PR766 --- examples/text-generation/utils.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index ca5de7ce58..54d08d017f 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -382,14 +382,7 @@ def initialize_model(args, logger): generation_config = setup_generation_config(args, model, tokenizer) if args.const_serialization_path: - import uuid - - args.const_serialization_path = os.path.join(args.const_serialization_path + uuid.uuid4().hex) - os.makedirs(args.const_serialization_path) - from habana_frameworks.torch.hpu import enable_const_section_serialization - - print("Serializing const params to {}".format(args.const_serialization_path)) - enable_const_section_serialization(args.const_serialization_path, True) + setup_const_serialization(args.const_serialization_path) if args.fp8: import habana_frameworks.torch.core as htcore From 799a1b7f84d97613cab741e5821048b0314bbbf9 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 23 Mar 2024 06:45:39 +0000 Subject: [PATCH 28/33] Fix falcon reuse_cache issue. --- optimum/habana/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index cbc38d737b..b48c5317d8 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -584,6 +584,7 @@ def generate( assert self.config.model_type in [ "llama", "mistral", + "falcon", ], "reuse_cache only supported by llama and mistral at the moment" if not generation_config.bucket_internal: assert ( From 1deba9b4666b4de0c3f8bbe35e15122c212eccbe Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 23 Mar 2024 07:20:27 +0000 Subject: [PATCH 29/33] Fix text-generation fp8 test env issue. --- tests/test_text_generation_example.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index e7c4a98a2f..0c9dac3417 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -129,15 +129,15 @@ def _test_text_generation( command = [x for y in command for x in re.split(pattern, y) if x] if fp8: - os.environ["QUANT_CONFIG"] = os.path.join( + env_variables["QUANT_CONFIG"] = os.path.join( path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json" ) - subprocess.run(command) - os.environ["QUANT_CONFIG"] = os.path.join( + subprocess.run(command, env=env_variables) + env_variables["QUANT_CONFIG"] = os.path.join( path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json" ) - proc = subprocess.run(command) + proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue # Use try-except to avoid logging the token if used From 460f923dc2cb7e4ffd3fbfa939b43fcc704121f2 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 25 Mar 2024 11:10:30 -0700 Subject: [PATCH 30/33] fix fp8 key error --- tests/test_text_generation_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 0c9dac3417..273f297cec 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -59,6 +59,7 @@ ("mistralai/Mistral-7B-v0.1", 40.00435417311187), ("microsoft/phi-2", 90.10751623430603), ], + "fp8": [], "deepspeed": [ ("bigscience/bloomz-7b1", 31.044523676681507), ], From 5496c6b5ee98f36bc1d3c538566448957642607b Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 25 Mar 2024 13:24:24 -0700 Subject: [PATCH 31/33] disable fsdp tests for gaudi1 --- tests/test_fsdp_examples.py | 60 +++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py index 248647cb53..777f441bfb 100644 --- a/tests/test_fsdp_examples.py +++ b/tests/test_fsdp_examples.py @@ -10,35 +10,37 @@ from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR -# Gaudi2 CI baselines -# FSDP is not supported on Gaudi1 -MODELS_TO_TEST = { - "bf16": [ - ( - "bert-base-uncased", - "Habana/bert-base-uncased", - 2807, - 85.4688, - "question-answering", - 24, - 8, - "run_qa.py", - "full_shard", - ), - ( - "meta-llama/Llama-2-7b-hf", - "", - 54, - 0.92, - "language-modeling", - 8, - 8, - "run_lora_clm.py", - "auto_wrap", - ), - ], -} - +if os.environ.get("GAUDI2_CI", "0") == "1": + # Gaudi2 CI baselines + MODELS_TO_TEST = { + "bf16": [ + ( + "bert-base-uncased", + "Habana/bert-base-uncased", + 2807, + 85.4688, + "question-answering", + 24, + 8, + "run_qa.py", + "full_shard", + ), + ( + "meta-llama/Llama-2-7b-hf", + "", + 54, + 0.92, + "language-modeling", + 8, + 8, + "run_lora_clm.py", + "auto_wrap", + ), + ], + } +else: + # FSDP is not supported on Gaudi1 + MODELS_TO_TEST = {"bf16":[]} def _test_fsdp( model_name: str, From ae98c19a15f508dd79a5e16828051c61b3b1dd92 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:52:09 +0000 Subject: [PATCH 32/33] Make style --- optimum/habana/transformers/modeling_utils.py | 1 - .../habana/transformers/models/__init__.py | 1 - .../transformers/models/wav2vec2/__init__.py | 1 - .../models/wav2vec2/modeling_wav2vec2.py | 83 ------------------- tests/test_fsdp_examples.py | 3 +- 5 files changed, 2 insertions(+), 87 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index a6fa20d66f..6dc40a73bf 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -135,7 +135,6 @@ gaudi_vit_self_attention_forward, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index ee84d9159c..1582d3f09e 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -146,7 +146,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py index 53b78fdefd..84372061b6 100644 --- a/optimum/habana/transformers/models/wav2vec2/__init__.py +++ b/optimum/habana/transformers/models/wav2vec2/__init__.py @@ -4,7 +4,6 @@ _gaudi_wav2vec2_sample_negative_indices, gaudi_wav2vec2_encoder_forward, gaudi_wav2vec2_forward, - gaudi_wav2vec2forctc_forward, gaudi_wav2vec2_tdnnlayer_forward, gaudi_wav2vec2forctc_forward, ) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index 9478f36267..a5cb5adf74 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -378,89 +378,6 @@ def gaudi_wav2vec2_encoder_forward( ) -_HIDDEN_STATES_START_POSITION = 2 - - -def gaudi_wav2vec2forctc_forward( - self, - input_values: Optional[torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - labels: Optional[torch.Tensor] = None, -) -> Union[Tuple, CausalLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): - Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to - the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. - All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., - config.vocab_size - 1]`. - """ - """ - copied from Transformers https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1950 - only differences are (1) attention_mask tensor generation using ones_like is done on HPU, (2) masked_select is not applied on labels to compute flattened_targets to avoid - changing flattened_targets tensor shapes across training iterations. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.wav2vec2( - input_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - if labels.max() >= self.config.vocab_size: - raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") - # retrieve loss input_lengths from attention_mask - attention_mask = ( - attention_mask - if attention_mask is not None - else torch.ones_like(input_values, dtype=torch.long, device="hpu") - ) - input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) - # assuming that padded tokens are filled with -100 - # when not being attended to - labels_mask = labels >= 0 - target_lengths = labels_mask.sum(-1) - # ctc_loss doesn't support fp16 - log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) - if get_device_name() == "GAUDI": - flattened_targets = labels.masked_select(labels_mask) - with torch.backends.cudnn.flags(enabled=False): - loss = torch.nn.functional.ctc_loss( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - blank=self.config.pad_token_id, - reduction=self.config.ctc_loss_reduction, - zero_infinity=self.config.ctc_zero_infinity, - ) - else: - flattened_targets = labels - with torch.backends.cudnn.flags(enabled=False): - loss = ctc_loss_fwd( - log_probs, - flattened_targets, - input_lengths, - target_lengths, - self.config.pad_token_id, - self.config.ctc_loss_reduction, - self.config.ctc_zero_infinity, - ) - - if not return_dict: - output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] - return ((loss,) + output) if loss is not None else output - return CausalLMOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) - - def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Copied from Transformers: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L2290 diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py index 777f441bfb..aec5cccc8f 100644 --- a/tests/test_fsdp_examples.py +++ b/tests/test_fsdp_examples.py @@ -40,7 +40,8 @@ } else: # FSDP is not supported on Gaudi1 - MODELS_TO_TEST = {"bf16":[]} + MODELS_TO_TEST = {"bf16": []} + def _test_fsdp( model_name: str, From f168ddbd67de65e863c348db0454a9136a798a44 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:13:49 +0000 Subject: [PATCH 33/33] Remove ctc_loss_fwd in modeling_wav2vec2 --- .../habana/transformers/models/wav2vec2/modeling_wav2vec2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py index a5cb5adf74..c6dd9cb546 100644 --- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -36,9 +36,6 @@ custom_ctc_loss_fwd = None -ctc_loss_fwd = CTCLoss.apply - - def _gaudi_wav2vec2_compute_mask_indices( shape: Tuple[int, int], mask_prob: float,