Run Llama2 with torch.compile on Gaudi2#605
Conversation
Signed-off-by: kausik <kmaiti@habana.ai>
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| output_hidden_states=output_hidden_states, | ||
| **hpu_graphs_kwargs, | ||
| ) | ||
| if torch_compile: |
There was a problem hiding this comment.
wrapping model only for greedy_search does not look right, it should probably be done in generate() so that it works for other modes (such as beam_search also),
There was a problem hiding this comment.
Not even sure we should do it in generate at all. If using the trainer, it should already be taken care of (see discussion above). Otherwise, for example in the text-generation example, I think we should just have a get_torch_compiled_model in text-generation/utils.py. That seems to be the way recommended by Transformers.
There was a problem hiding this comment.
@regisss thanks for your comments, we will check if we can go with adding get_torch_compiled_model in text-generation/utils.py
There was a problem hiding this comment.
Ok. I would create 'get_torch_compiled_model' in text-generation/utils.py.
| negative_prompt_ids: Optional[torch.Tensor] = None, | ||
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, | ||
| lazy_mode: Optional[bool] = False, | ||
| torch_compile: Optional[bool] = False, |
There was a problem hiding this comment.
For normal training, eval, predict models are wrapped within accelerator.prepare_model() call, adding new code for generate() may not be aligned. @regisss any idea how direct model.generate() calls are handled in transformers for compile mode, I tried to search there but did not find anything.
There was a problem hiding this comment.
In the trainer, the link with Accelerate is made here:
And then in Accelerate it happens here:
It was introduced in #465.
There was a problem hiding this comment.
Outside of the trainer, Transformers recommends to simply use:
model = torch.compile(model)
https://huggingface.co/docs/transformers/v4.36.1/en/perf_torch_compile
There was a problem hiding this comment.
As suggested, I would create 'get_torch_compiled_model()' in text-generation/utils.py. And this will be called inside setup_model() in text-generation/utils.py.
| help="Whether to use the key/value cache for decoding. It should speed up generation.", | ||
| ) | ||
| parser.add_argument( | ||
| "--use_torch_compile", |
There was a problem hiding this comment.
| "--use_torch_compile", | |
| "--torch_compile", |
to be aligned with Transformers and GaudiTrainingArguments
There was a problem hiding this comment.
ok. I would change.
| negative_prompt_ids: Optional[torch.Tensor] = None, | ||
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, | ||
| lazy_mode: Optional[bool] = False, | ||
| torch_compile: Optional[bool] = False, |
There was a problem hiding this comment.
In the trainer, the link with Accelerate is made here:
And then in Accelerate it happens here:
It was introduced in #465.
| negative_prompt_ids: Optional[torch.Tensor] = None, | ||
| negative_prompt_attention_mask: Optional[torch.Tensor] = None, | ||
| lazy_mode: Optional[bool] = False, | ||
| torch_compile: Optional[bool] = False, |
There was a problem hiding this comment.
Outside of the trainer, Transformers recommends to simply use:
model = torch.compile(model)
https://huggingface.co/docs/transformers/v4.36.1/en/perf_torch_compile
| output_hidden_states=output_hidden_states, | ||
| **hpu_graphs_kwargs, | ||
| ) | ||
| if torch_compile: |
There was a problem hiding this comment.
Not even sure we should do it in generate at all. If using the trainer, it should already be taken care of (see discussion above). Otherwise, for example in the text-generation example, I think we should just have a get_torch_compiled_model in text-generation/utils.py. That seems to be the way recommended by Transformers.
|
I created a separate PR after making necessary changes. Kindly refer to #616 |
… generation tests (huggingface#2200) (huggingface#605) Co-authored-by: Grzegorz Pluto-Prondzinski <gplutopx@habana.ai>
What does this PR do?
This change allows the user to run Llama2 model with torch.compile on Gaudi2.