From e8c616291696284d2b3a2dbbd47a120c211fa657 Mon Sep 17 00:00:00 2001 From: kausik Date: Thu, 28 Dec 2023 06:45:57 +0200 Subject: [PATCH 1/3] Run Llama2 with torch.compile on Gaudi2 Signed-off-by: kausik --- examples/text-generation/run_generation.py | 16 ++++++++++++++-- examples/text-generation/utils.py | 9 +++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 445794048f..c4de443940 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -231,9 +231,17 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--torch_compile", + action="store_true", + help="Whether to use torch compiled model or not.", + ) args = parser.parse_args() + if args.torch_compile: + args.use_hpu_graphs = False + if not args.use_hpu_graphs: args.limit_hpu_graphs = False @@ -245,6 +253,10 @@ def main(): args = setup_parser(parser) model, tokenizer, generation_config = initialize_model(args, logger) + use_lazy_mode = True + if args.torch_compile: + use_lazy_mode = False + import habana_frameworks.torch.hpu as torch_hpu if args.dataset_name is None: @@ -297,7 +309,7 @@ def generate(size=None, reduce_recompile=False): outputs = model.generate( **input_tokens, generation_config=generation_config, - lazy_mode=True, + lazy_mode=use_lazy_mode, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, @@ -477,7 +489,7 @@ def generate_dataset(batch): outputs = model.generate( **batch, generation_config=generation_config, - lazy_mode=True, + lazy_mode=use_lazy_mode, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 225ae7c5ad..baeab6173d 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -151,6 +151,11 @@ def patch_scoped_linear_all_reduce(model): patch_scoped_linear_all_reduce(module) +def get_torch_compiled_model(model): + model.model = torch.compile(model.model, backend="aot_hpu_inference_backend") + return model + + def setup_model(args, model_dtype, model_kwargs, logger): logger.info("Single-device run.") @@ -170,6 +175,10 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = wrap_in_hpu_graph(model, hash_with_views=not args.skip_hash_with_views) else: model = wrap_in_hpu_graph(model) + + if args.torch_compile: + model = get_torch_compiled_model(model) + return model From e13946fbe877a6e80f6be8cd1b0ad9baf61e0cd8 Mon Sep 17 00:00:00 2001 From: kausik Date: Sun, 7 Jan 2024 12:27:44 +0200 Subject: [PATCH 2/3] Added model specific check to enable torch.compile support only for Llama2 Signed-off-by: kausik --- examples/text-generation/run_generation.py | 2 +- examples/text-generation/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index c4de443940..e5300cf5e9 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -254,7 +254,7 @@ def main(): model, tokenizer, generation_config = initialize_model(args, logger) use_lazy_mode = True - if args.torch_compile: + if args.torch_compile and model.config.model_type == "llama": use_lazy_mode = False import habana_frameworks.torch.hpu as torch_hpu diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index baeab6173d..c964b86ecb 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -176,7 +176,7 @@ def setup_model(args, model_dtype, model_kwargs, logger): else: model = wrap_in_hpu_graph(model) - if args.torch_compile: + if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) return model From 03596e6a12204aba9ee1d64041015dd3db1cc49c Mon Sep 17 00:00:00 2001 From: kausik Date: Thu, 18 Jan 2024 12:29:05 +0200 Subject: [PATCH 3/3] Added a test Signed-off-by: kausik --- examples/text-generation/utils.py | 2 +- tests/test_text_generation_example.py | 35 +++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 4bf7fb641c..46ef9fa1f6 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -115,7 +115,7 @@ def setup_env(args): check_min_version("4.34.0") check_optimum_habana_min_version("1.9.0.dev0") - if args.global_rank == 0: + if args.global_rank == 0 and not args.torch_compile: os.environ.setdefault("GRAPH_VISUALIZATION", "true") shutil.rmtree(".graph_dumps", ignore_errors=True) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index faae5b79db..20ff84ac1d 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -30,6 +30,9 @@ ("meta-llama/Llama-2-70b-hf", 58.2750262232098), ("facebook/opt-66b", 28.16154122335556), ], + "torch_compile": [ + ("meta-llama/Llama-2-7b-hf", 8.95169640119334), + ], } else: # Gaudi1 CI baselines @@ -50,13 +53,22 @@ "deepspeed": [ ("bigscience/bloomz-7b1", 27.34439410425298), ], + "torch_compile": [], } -def _test_text_generation(model_name: str, baseline: float, token: str, deepspeed: bool = False, world_size: int = 8): +def _test_text_generation( + model_name: str, + baseline: float, + token: str, + deepspeed: bool = False, + world_size: int = 8, + torch_compile: bool = False, +): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" + deepspeed = deepspeed and not torch_compile if deepspeed: command += [ f"{path_to_example_dir / 'gaudi_spawn.py'}", @@ -68,11 +80,22 @@ def _test_text_generation(model_name: str, baseline: float, token: str, deepspee f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", f"--model_name_or_path {model_name}", "--batch_size 1", - "--use_hpu_graphs", "--use_kv_cache", "--max_new_tokens 100", ] + if torch_compile: + command += [ + "--attn_softmax_bf16", + "--reuse_cache", + "--trim_logits", + "--torch_compile", + ] + else: + command += [ + "--use_hpu_graphs", + ] + if not deepspeed: command.append("--bf16") @@ -115,3 +138,11 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str): def test_text_generation_deepspeed(model_name: str, baseline: float, token: str): world_size = 2 if "opt-66b" in model_name else 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size) + + +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile"]) +def test_text_generation_torch_compile(model_name: str, baseline: float, token: str): + os.environ["PT_ENABLE_INT64_SUPPORT"] = "1" + os.environ["PT_HPU_LAZY_MODE"] = "0" + os.environ["WORLD_SIZE"] = "0" + _test_text_generation(model_name, baseline, token, torch_compile=True)