From 4444f1b1d0a2c2c5aba489ba53aa678110f6f0a0 Mon Sep 17 00:00:00 2001 From: kausik Date: Mon, 18 Dec 2023 14:50:04 +0200 Subject: [PATCH] [SW-169007] Enable torch.compile support for Llama2 Signed-off-by: kausik --- examples/text-generation/run_generation.py | 14 ++++++- .../habana/transformers/generation/utils.py | 39 +++++++++++++++---- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 445794048f..69a2e7ad97 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -73,6 +73,11 @@ def setup_parser(parser): action="store_true", help="Whether to use the key/value cache for decoding. It should speed up generation.", ) + parser.add_argument( + "--use_torch_compile", + action="store_true", + help="Whether to use torch compiled model or not.", + ) parser.add_argument( "--use_hpu_graphs", action="store_true", @@ -234,6 +239,9 @@ def setup_parser(parser): args = parser.parse_args() + if args.use_torch_compile: + args.use_hpu_graphs = False + if not args.use_hpu_graphs: args.limit_hpu_graphs = False @@ -297,7 +305,8 @@ def generate(size=None, reduce_recompile=False): outputs = model.generate( **input_tokens, generation_config=generation_config, - lazy_mode=True, + lazy_mode=True if not args.use_torch_compile else False, + torch_compile = args.use_torch_compile, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, @@ -477,7 +486,8 @@ def generate_dataset(batch): outputs = model.generate( **batch, generation_config=generation_config, - lazy_mode=True, + lazy_mode=True if not args.use_torch_compile else False, + torch_compile = args.use_torch_compile, hpu_graphs=args.use_hpu_graphs, profiling_steps=args.profiling_steps, profiling_warmup_steps=args.profiling_warmup_steps, diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bc8fad5118..9439e02aef 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -406,6 +406,7 @@ def generate( negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, lazy_mode: Optional[bool] = False, + torch_compile: Optional[bool] = False, hpu_graphs: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -474,6 +475,8 @@ def generate( Attention_mask for `negative_prompt_ids`. lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). + torch_compile (`bool`, *optional*, defaults to `False`): + Whether the run is executed with torch.compile model or not. hpu_graphs (`bool`, *optional*, defaults to `False`): Whether to use HPU graphs for inference. profiling_warmup_steps (`int`, *optional*, defaults to 0): @@ -513,6 +516,10 @@ def generate( raise ValueError( "`hpu_graphs` is True but `lazy_mode` is False. HPU graphs require `lazy_mode` to be set to True." ) + if torch_compile and (lazy_mode or hpu_graphs): + raise ValueError( + "`torch_compile` is True. This requires both `lazy_mode` and `hpu_graphs` to be set to False." + ) # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: @@ -838,6 +845,7 @@ def generate( synced_gpus=synced_gpus, streamer=streamer, lazy_mode=lazy_mode, + torch_compile=torch_compile, ignore_eos=generation_config.ignore_eos, profiling_warmup_steps=profiling_warmup_steps, profiling_steps=profiling_steps, @@ -1214,6 +1222,7 @@ def greedy_search( synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, lazy_mode: Optional[bool] = False, + torch_compile: Optional[bool] = False, ignore_eos: Optional[bool] = False, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, @@ -1265,6 +1274,8 @@ def greedy_search( through `streamer.put(token_ids)` and the streamer is responsible for any further processing. lazy_mode (`bool`, *optional*, defaults to `False`): Whether the run is executed in lazy mode or not (i.e. eager mode). + torch_compile (`bool`, *optional*, defaults to `False`): + Whether the run is executed with torch.compile model or not. ignore_eos (`bool`, *optional*, defaults to `False`): Whether to ignore finished sequences (faster in lazy mode and with HPU graphs) or not (eager mode). profiling_warmup_steps (`int`, *optional*, defaults to 0): @@ -1403,14 +1414,26 @@ def greedy_search( hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - **hpu_graphs_kwargs, - ) + if torch_compile: + # apply torch.compile + compiled_model = torch.compile(self, backend="aot_hpu_inference_backend") + # forward pass to get next token + outputs = compiled_model( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **hpu_graphs_kwargs, + ) + else: + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **hpu_graphs_kwargs, + ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need