From d9d4fc079cc065c3d0f42efe2c5eca0e6bc4ea3b Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 29 Mar 2024 23:08:38 +0000 Subject: [PATCH 01/12] Release: v1.11.0 --- optimum/habana/version.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/version.py b/optimum/habana/version.py index 497fd1da79..714a1d7075 100644 --- a/optimum/habana/version.py +++ b/optimum/habana/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.11.0.dev0" +__version__ = "1.11.0" diff --git a/setup.py b/setup.py index 5904ba2ea1..c822c41e7f 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ QUALITY_REQUIRES = [ "ruff", - "hf_doc_builder @ git+https://github.com/huggingface/doc-builder.git", + "hf_doc_builder", ] EXTRAS_REQUIRE = { From 4160e9c8a068a515e3981b37690cebe1551ffe8c Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sat, 30 Mar 2024 02:30:23 -0700 Subject: [PATCH 02/12] Fix fp8 ci (#852) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- tests/test_text_generation_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 8f3da77526..48150d635c 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -136,7 +136,7 @@ def _test_text_generation( env_variables["QUANT_CONFIG"] = os.path.join( path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json" ) - command.insert(-1, "--fp8") + command.insert(-2, "--fp8") proc = subprocess.run(command, env=env_variables) From b0eefc55da46276b75fdf17b75f67d33ff92f0f3 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sat, 30 Mar 2024 10:41:13 +0100 Subject: [PATCH 03/12] Fix PR #848 (#853) --- tests/test_examples.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 5cf2559f5f..6eefb5c571 100755 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -276,7 +276,7 @@ def test(self): self.assertEqual(return_code, 0) return elif self.EXAMPLE_NAME == "run_clip": - if not os.environ.get("DATA_CACHE", "0"): + if os.environ.get("DATA_CACHE", None) is None: from .clip_coco_utils import COCO_URLS, download_files download_files(COCO_URLS) @@ -327,8 +327,8 @@ def test(self): extra_command_line_arguments = baseline.get("distribution").get(distribution).get("extra_arguments", []) - if os.environ.get("DATA_CACHE", "0") and self.EXAMPLE_NAME == "run_clip": - extra_command_line_arguments[0] = "--data_dir {}".format(os.environ.get("DATA_CACHE", "$PWD")) + if os.environ.get("DATA_CACHE", None) is not None and self.EXAMPLE_NAME == "run_clip": + extra_command_line_arguments[0] = "--data_dir {}".format(os.environ["DATA_CACHE"]) with TemporaryDirectory() as tmp_dir: cmd_line = self._create_command_line( @@ -410,7 +410,7 @@ def _create_command_line( task: Optional[str] = None, extra_command_line_arguments: Optional[List[str]] = None, ) -> List[str]: - dataset_name = self.DATASET_NAME if self.DATASET_NAME else task + dataset_name = self.DATASET_NAME if self.DATASET_NAME is not None else task task_option = f"--{self.DATASET_PARAMETER_NAME} {dataset_name}" if task else " " cmd_line = ["python3"] @@ -583,7 +583,7 @@ class MultiCardSpeechRecognitionExampleTester( ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_speech_recognition_ctc", multi_card=True ): TASK_NAME = "regisss/librispeech_asr_for_optimum_habana_ci" - DATASET_NAME = os.environ.get("DATA_CACHE", 0) + DATASET_NAME = os.environ.get("DATA_CACHE", None) class MultiCardSummarizationExampleTester( From 8ee87de9085d13593c1b20ada9b1b99a0715da69 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sat, 30 Mar 2024 11:00:50 +0100 Subject: [PATCH 04/12] Disable safe loading tests in CI (#854) --- tests/test_trainer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1963f805c7..2cb3147523 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -49,6 +49,7 @@ get_gpu_count, get_tests_dir, is_staging_test, + parse_flag_from_env, require_optuna, require_safetensors, require_sentencepiece, @@ -90,6 +91,20 @@ PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" +_run_safe_loading_tests_ = parse_flag_from_env("SAFE_LOADING_TESTS", default=False) + + +def safe_loading_test(test_case): + """ + Decorator marking a test as needing custom bf16 ops. + Custom bf16 ops must be declared before `habana_frameworks.torch.core` is imported, which is not possible if some other tests are executed before. + + Such tests are skipped by default. Set the CUSTOM_BF16_OPS environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_safe_loading_tests_, "test requires SAFE_LOADING_TESTS")(test_case) + + class RegressionDataset: def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): np.random.seed(seed) @@ -1465,6 +1480,7 @@ def test_training_with_resume_from_checkpoint_false(self): trainer.train(resume_from_checkpoint=False) + @safe_loading_test @require_safetensors def test_resume_training_with_safe_checkpoint(self): # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of @@ -1658,6 +1674,7 @@ def test_load_best_model_at_end(self): self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False) self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False) + @safe_loading_test @require_safetensors def test_load_best_model_from_safetensors(self): total = int(self.n_epochs * 64 / self.batch_size) From 1c7b2cd3c05db740a2345aa3a1af75a6b5aff946 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sun, 31 Mar 2024 12:39:12 +0200 Subject: [PATCH 05/12] Update QA example --- examples/question-answering/run_qa.py | 8 ++++++++ tests/example_diff/run_qa.txt | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index e95e014f92..3ed40396da 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -649,6 +649,14 @@ def post_processing_function(examples, features, predictions, stage="eval"): references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) + if data_args.version_2_with_negative: + accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1") + else: + accepted_best_metrics = ("exact_match", "f1") + + if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics: + warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}") + metric = evaluate.load( "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir ) diff --git a/tests/example_diff/run_qa.txt b/tests/example_diff/run_qa.txt index ec84749d7f..2f43379f2c 100644 --- a/tests/example_diff/run_qa.txt +++ b/tests/example_diff/run_qa.txt @@ -67,9 +67,9 @@ > if tokenizer.pad_token is None: > tokenizer.add_special_tokens({"pad_token": "[PAD]"}) > tokenizer.cls_token = tokenizer.bos_token -639a662 +647a670 > gaudi_config=gaudi_config, -708,712d730 +716,720d738 < < < def _mp_fn(index): From c87d3123011c8fe41850c6b8b56de7bc412b59a4 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Sun, 31 Mar 2024 12:41:01 +0200 Subject: [PATCH 06/12] Update Bert large Gaudi1 CI baseline --- tests/baselines/bert_large_uncased_whole_word_masking.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/baselines/bert_large_uncased_whole_word_masking.json b/tests/baselines/bert_large_uncased_whole_word_masking.json index d153328e4a..c9d67aeeea 100644 --- a/tests/baselines/bert_large_uncased_whole_word_masking.json +++ b/tests/baselines/bert_large_uncased_whole_word_masking.json @@ -47,8 +47,8 @@ "learning_rate": 3e-5, "train_batch_size": 16, "eval_f1": 0.8897, - "train_runtime": 64.4986, - "train_samples_per_second": 968.596, + "train_runtime": 65.644, + "train_samples_per_second": 919.623, "extra_arguments": [ "--max_seq_length 128", "--use_hpu_graphs_for_inference" From 445be214769ffd0c11610471a09b3bc56e19b420 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Sun, 31 Mar 2024 04:51:00 -0700 Subject: [PATCH 07/12] Add warmup for eval (#855) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- optimum/habana/transformers/trainer.py | 73 ++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 4f22bdde37..c91d0724f8 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1621,6 +1621,73 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + def evaluate( + self, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + """ + From https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/trainer.py#L3162 with the following modification + 1. comment out TPU related + 2. use throughput_warmup_steps in evaluation throughput calculation + """ + # handle multipe eval datasets + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + if isinstance(eval_dataset, dict): + metrics = {} + for eval_dataset_name, _eval_dataset in eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=_eval_dataset, + ignore_keys=ignore_keys, + metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + return metrics + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + start_time = time.time() + self.start_time_after_warmup = None + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size + num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps + + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=num_samples, + num_steps=num_steps, + start_time_after_warmup=self.start_time_after_warmup, + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + def evaluation_loop( self, dataloader: DataLoader, @@ -1716,6 +1783,12 @@ def evaluation_loop( observed_num_examples = 0 # Main evaluation loop for step, inputs in enumerate(dataloader): + if ( + self.args.throughput_warmup_steps > 0 + and not self.is_in_train + and step == self.args.throughput_warmup_steps + ): + self.start_time_after_warmup = time.time() # Update the observed num examples observed_batch_size = find_batch_size(inputs) if observed_batch_size is not None: From eaac913c56617a4deefa99548d89366722b7397e Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 3 Apr 2024 03:32:50 -0700 Subject: [PATCH 08/12] Fix mistral after syn1.15 update (#858) --- optimum/habana/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index bda8738516..cf5fa6f2c0 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -518,7 +518,7 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _, __): + def allocate_kv_cache(self, batch_size, seq_len, _): self.model.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): From 58503c59f71c11240ceb6580a4548b7e66071b79 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Thu, 4 Apr 2024 08:54:50 -0700 Subject: [PATCH 09/12] Fp8 merge fix (#863) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/text-generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 54d08d017f..a83242e8b3 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -104,7 +104,7 @@ def setup_const_serialization(const_serialization_path): from habana_frameworks.torch.hpu import enable_const_section_serialization print("Serializing const params to {}".format(const_serialization_path)) - enable_const_section_serialization(const_serialization_path, False, True) + enable_const_section_serialization(const_serialization_path, True) def setup_env(args): From 84e824104ae11851ecb34166fad3d2be04ba2bc6 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 5 Apr 2024 15:22:36 +0530 Subject: [PATCH 10/12] Add mark step and inplace residual add in llama model code (#833) Signed-off-by: Puneesh Khanna --- .../habana/transformers/generation/utils.py | 4 ++ .../models/llama/modeling_llama.py | 53 +++++++++++++------ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 92df17bb50..b46481da24 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1431,6 +1431,7 @@ def greedy_search( ) # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1782,6 +1783,7 @@ def sample( break # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2228,6 +2230,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # if sequential is True, split the input to batches of batch_size and run sequentially @@ -3008,6 +3011,7 @@ def constrained_beam_search( if this_peer_finished_flag.item() == 0.0: break + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4d0f3513d7..4bcc32b17b 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -45,6 +45,8 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore + def gaudi_llama_rmsnorm_forward(self, hidden_states): """ @@ -514,7 +516,7 @@ def forward( ) residual = hidden_states - output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, position_ids, @@ -530,13 +532,12 @@ def forward( cache_idx=cache_idx, **kwargs, ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) - self.self_attn.attention_all_reduce(output_pre_attn) - output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) - self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) - output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) - - outputs = (output_post_mlp,) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) @@ -562,7 +563,7 @@ def pre_attn( cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) - output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( hidden_states, attention_mask, position_ids, @@ -577,23 +578,33 @@ def pre_attn( flash_attention_recompute, cache_idx=cache_idx, ) - return output_attn, attn_weights, present_key_value + return hidden_states, attn_weights, present_key_value - def post_attn_pre_mlp(self, input, residual): - output_post_attn = self.self_attn.post_attn_forward(input) + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - hidden_states = residual + output_post_attn - residual = hidden_states + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp.pre_mlp_forward(hidden_states) return hidden_states, residual - def post_mlp(self, input, residual): - output_post_mlp = self.mlp.post_mlp_forward(input) - output = output_post_mlp + residual - return output + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states class GaudiLlamaModel(LlamaModel): @@ -658,6 +669,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -667,6 +679,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -743,6 +756,9 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -849,6 +865,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -877,6 +894,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -1010,6 +1028,7 @@ def prepare_inputs_for_generation( "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs From af85fd0b321593bb686f9af6f3b6ccc73adfc0ac Mon Sep 17 00:00:00 2001 From: Witold Szczurek <152967125+wszczurekhabana@users.noreply.github.com> Date: Fri, 5 Apr 2024 23:31:18 +0200 Subject: [PATCH 11/12] Enable Flash Attention in recompute and causal modes (#21) (#862) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> Co-authored-by: Libin Tang --- examples/language-modeling/README.md | 3 +- examples/language-modeling/run_lora_clm.py | 12 ++++- examples/text-generation/README.md | 24 +++++++++ examples/text-generation/run_generation.py | 54 +++++++++++++++++++ examples/text-generation/utils.py | 2 + .../generation/configuration_utils.py | 3 ++ .../habana/transformers/generation/utils.py | 2 + .../models/llama/modeling_llama.py | 27 ++++++++-- optimum/habana/transformers/trainer.py | 4 ++ 9 files changed, 125 insertions(+), 6 deletions(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 776993aca1..72853a6811 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -556,7 +556,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Llama2-70B with FSDP and LoRA: diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index d057fc3c94..91d139ce6b 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -155,6 +155,15 @@ class ModelArguments: ) }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True.", + ) + }, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -547,7 +556,8 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute - if model_args.use_fused_rope is False: + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 0f9a2c7b16..e6563e433d 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -354,6 +354,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ ``` `--fp8` is required to enable quantization in fp8. +### Using Habana Flash Attention + +Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. + +Below example uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally since all sequences in a batch are of the same length it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation. + +```bash +python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--reuse_cache \ +--trim_logits \ +--attn_softmax_bf16 \ +--max_input_tokens 31744 \ +--max_new_tokens 1024 \ +--batch_size=12 \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask \ +--book_source +``` + +For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 1f503ed5e1..0b4ff8b5af 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -227,6 +227,21 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--flash_attention_recompute", + action="store_true", + help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", + ) + parser.add_argument( + "--flash_attention_causal_mask", + action="store_true", + help="Whether to enable Habana Flash Attention in causal mode on first token generation.", + ) + parser.add_argument( + "--book_source", + action="store_true", + help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", + ) parser.add_argument( "--torch_compile", action="store_true", @@ -272,6 +287,45 @@ def main(): # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt + elif args.book_source: + + def download_book(book_id): + import os + + import requests + + url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" + response = requests.get(url) + if response.status_code == 200: + pid = os.getpid() + save_path = f"/tmp/{book_id}_{pid}.txt" + with open(save_path, "wb") as file: + file.write(response.content) + print(f"Book downloaded and saved to: {save_path}") + return save_path + else: + print("Failed to download book! Exiting...") + import sys + + sys.exit() + + def assemble_prompt(prompt_size, book_path): + prompt = "" + counter = 0 + book_lines = open(book_path).readlines() + for line in book_lines: + for word in line.split(): + counter += 1 + prompt += word + " " + if counter == prompt_size: + return [prompt] * args.batch_size + + book_ids = [ + 2701, # Moby Dick; Or, The Whale + 1513, # Romeo and Juliet + 1342, # Pride and Prejudice + ] + input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) else: input_sentences = [ "DeepSpeed is a machine learning framework", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index a83242e8b3..9b60fe4920 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -347,6 +347,8 @@ def setup_generation_config(args, model, tokenizer): if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 generation_config.use_flash_attention = args.use_flash_attention + generation_config.flash_attention_recompute = args.flash_attention_recompute + generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask return generation_config diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 93df1335db..61585b559f 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): Whether to enable recompute if use Habana flash attention. + flash_attention_causal_mask (`bool`, *optional*): + Whether to enable causal_mask if use Habana flash attention. """ def __init__(self, **kwargs): @@ -48,4 +50,5 @@ def __init__(self, **kwargs): self.reduce_recompile = kwargs.get("reduce_recompile", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index b46481da24..0d50470532 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -726,6 +726,8 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False + if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4bcc32b17b..1381f30b1e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -313,6 +313,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -325,6 +326,7 @@ def pre_attn_forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ bsz, q_len, _ = hidden_states.size() @@ -408,10 +410,15 @@ def pre_attn_forward( ) else: # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -498,6 +505,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -509,6 +517,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ if "padding_mask" in kwargs: warnings.warn( @@ -529,6 +538,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, **kwargs, ) @@ -560,6 +570,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) @@ -576,6 +587,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, cache_idx=cache_idx, ) return hidden_states, attn_weights, present_key_value @@ -668,6 +680,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -679,6 +692,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -778,6 +792,7 @@ def forward( False, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, ) else: layer_outputs = decoder_layer( @@ -793,6 +808,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -864,6 +880,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -893,6 +910,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, lazy_mode=lazy_mode, ) @@ -1027,6 +1045,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), } diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index c91d0724f8..dc6e136a41 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -925,6 +925,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1806,6 +1808,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) From 8cdeee1db8951580708528291b9efe7c097af6be Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Mon, 8 Apr 2024 21:35:55 +0000 Subject: [PATCH 12/12] Add mark_step for inference(Propagate OHF PRs 126, 96, 75) --- optimum/habana/transformers/models/llama/modeling_llama.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1381f30b1e..6ab8bcf583 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -774,6 +774,13 @@ def forward( htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + if output_hidden_states: all_hidden_states += (hidden_states,)