From c23fca0b413cbd701182987aaffa0f5b6126f997 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Thu, 21 Dec 2023 11:28:43 +0200 Subject: [PATCH 1/3] Adding support for bf16_full_eval --- optimum/habana/transformers/trainer.py | 17 +++++++++++++++++ optimum/habana/transformers/training_args.py | 3 --- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 0ff64b560a..6eca5f280a 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -428,6 +428,11 @@ def train( self.is_in_train = True + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + if "model_path" in kwargs: resume_from_checkpoint = kwargs.pop("model_path") warnings.warn( @@ -1510,6 +1515,12 @@ def evaluation_loop( ) self.already_wrapped_for_hpu_graphs = True + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + batch_size = self.args.eval_batch_size logger.info(f"***** Running {description} *****") @@ -1903,6 +1914,12 @@ def prediction_loop( ) self.already_wrapped_for_hpu_graphs = True + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) logger.info(f"***** Running {description} *****") diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 740e682f75..2cc6917c0d 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -54,7 +54,6 @@ # List of arguments that are not supported by optimum-habana UNSUPPORTED_ARGUMENTS = [ - "bf16_full_eval", "fp16", "fp16_backend", "fp16_full_eval", @@ -314,8 +313,6 @@ def __post_init__(self): raise ValueError("must be using hpu graphs to set max_hpu_graphs.") # Raise errors for arguments that are not supported by optimum-habana - if self.bf16_full_eval: - raise ValueError("--bf16_full_eval is not supported by optimum-habana.") if self.fp16 or self.fp16_full_eval: raise ValueError( "--fp16, --fp16_backend, --fp16_full_eval and --fp16_opt_level are not" From 44dfc973d780e8412c0d0db1a6bc0bab811a8488 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Wed, 27 Dec 2023 07:45:44 +0200 Subject: [PATCH 2/3] Adding changes for converting model for fp16 flag --- examples/summarization/README.md | 3 ++- optimum/habana/transformers/trainer.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/summarization/README.md b/examples/summarization/README.md index 8ebed989ed..3bdf382c72 100644 --- a/examples/summarization/README.md +++ b/examples/summarization/README.md @@ -227,7 +227,8 @@ python run_summarization.py \ --gaudi_config_name Habana/t5 \ --ignore_pad_token_for_loss False \ --pad_to_max_length \ - --bf16 + --bf16 \ + --bf16_full_eval ``` You can run inference with BART on the CNN-DailyMail dataset on 1 Gaudi card with the following command: diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 6eca5f280a..c6c13ba790 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1518,7 +1518,9 @@ def evaluation_loop( # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: - if args.bf16_full_eval: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = self.args.eval_batch_size @@ -1917,7 +1919,9 @@ def prediction_loop( # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: - if args.bf16_full_eval: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = dataloader.batch_size From 9f62538ebccc7709fd09da4498087e69884a149c Mon Sep 17 00:00:00 2001 From: Bhargav Date: Wed, 27 Dec 2023 08:16:51 +0200 Subject: [PATCH 3/3] Changing perf numbers --- tests/test_encoder_decoder_text_summarization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_encoder_decoder_text_summarization.py b/tests/test_encoder_decoder_text_summarization.py index 506ae0e04e..197cc37e7a 100644 --- a/tests/test_encoder_decoder_text_summarization.py +++ b/tests/test_encoder_decoder_text_summarization.py @@ -15,7 +15,7 @@ MODELS_TO_TEST = { "bf16": [ ("facebook/bart-large-cnn", "Habana/bart", 4.691, 26.0688, 2, 1), - ("t5-3b", "Habana/t5", 2.28, 21.56, 2, 1), + ("t5-3b", "Habana/t5", 2.88, 21.56, 2, 1), ], } else: @@ -23,7 +23,7 @@ MODELS_TO_TEST = { "bf16": [ ("facebook/bart-large-cnn", "Habana/bart", 2.588, 26.0688, 2, 1), - ("t5-3b", "Habana/t5", 0.585, 21.72, 2, 1), + ("t5-3b", "Habana/t5", 0.98, 21.56, 2, 1), ], } @@ -76,6 +76,8 @@ def _test_text_summarization( if not deepspeed: command.append("--bf16") + if model_name == "t5-3b": + command.append("--bf16_full_eval") with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}")