diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ab308e7023..813f40d790 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -231,11 +231,19 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--torch_compile", + action="store_true", + help="Whether to use torch compiled model or not.", + ) parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") args = parser.parse_args() + if args.torch_compile: + args.use_hpu_graphs = False + if not args.use_hpu_graphs: args.limit_hpu_graphs = False @@ -247,6 +255,10 @@ def main(): args = setup_parser(parser) model, tokenizer, generation_config = initialize_model(args, logger) + use_lazy_mode = True + if args.torch_compile and model.config.model_type == "llama": + use_lazy_mode = False + import habana_frameworks.torch.hpu as torch_hpu if args.dataset_name is None: @@ -299,7 +311,7 @@ def generate(size=None, reduce_recompile=False): outputs = model.generate( **input_tokens, generation_config=generation_config, - lazy_mode=True, + lazy_mode=use_lazy_mode, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, @@ -479,7 +491,7 @@ def generate_dataset(batch): outputs = model.generate( **batch, generation_config=generation_config, - lazy_mode=True, + lazy_mode=use_lazy_mode, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 5c03de7dc6..46ef9fa1f6 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -115,7 +115,7 @@ def setup_env(args): check_min_version("4.34.0") check_optimum_habana_min_version("1.9.0.dev0") - if args.global_rank == 0: + if args.global_rank == 0 and not args.torch_compile: os.environ.setdefault("GRAPH_VISUALIZATION", "true") shutil.rmtree(".graph_dumps", ignore_errors=True) @@ -151,6 +151,11 @@ def patch_scoped_linear_all_reduce(model): patch_scoped_linear_all_reduce(module) +def get_torch_compiled_model(model): + model.model = torch.compile(model.model, backend="aot_hpu_inference_backend") + return model + + def setup_model(args, model_dtype, model_kwargs, logger): logger.info("Single-device run.") @@ -170,6 +175,10 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = wrap_in_hpu_graph(model, hash_with_views=not args.skip_hash_with_views) else: model = wrap_in_hpu_graph(model) + + if args.torch_compile and model.config.model_type == "llama": + model = get_torch_compiled_model(model) + return model diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index faae5b79db..20ff84ac1d 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -30,6 +30,9 @@ ("meta-llama/Llama-2-70b-hf", 58.2750262232098), ("facebook/opt-66b", 28.16154122335556), ], + "torch_compile": [ + ("meta-llama/Llama-2-7b-hf", 8.95169640119334), + ], } else: # Gaudi1 CI baselines @@ -50,13 +53,22 @@ "deepspeed": [ ("bigscience/bloomz-7b1", 27.34439410425298), ], + "torch_compile": [], } -def _test_text_generation(model_name: str, baseline: float, token: str, deepspeed: bool = False, world_size: int = 8): +def _test_text_generation( + model_name: str, + baseline: float, + token: str, + deepspeed: bool = False, + world_size: int = 8, + torch_compile: bool = False, +): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" + deepspeed = deepspeed and not torch_compile if deepspeed: command += [ f"{path_to_example_dir / 'gaudi_spawn.py'}", @@ -68,11 +80,22 @@ def _test_text_generation(model_name: str, baseline: float, token: str, deepspee f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", f"--model_name_or_path {model_name}", "--batch_size 1", - "--use_hpu_graphs", "--use_kv_cache", "--max_new_tokens 100", ] + if torch_compile: + command += [ + "--attn_softmax_bf16", + "--reuse_cache", + "--trim_logits", + "--torch_compile", + ] + else: + command += [ + "--use_hpu_graphs", + ] + if not deepspeed: command.append("--bf16") @@ -115,3 +138,11 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str): def test_text_generation_deepspeed(model_name: str, baseline: float, token: str): world_size = 2 if "opt-66b" in model_name else 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size) + + +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile"]) +def test_text_generation_torch_compile(model_name: str, baseline: float, token: str): + os.environ["PT_ENABLE_INT64_SUPPORT"] = "1" + os.environ["PT_HPU_LAZY_MODE"] = "0" + os.environ["WORLD_SIZE"] = "0" + _test_text_generation(model_name, baseline, token, torch_compile=True)