[Trainer] Force is_model_parallel when model is loaded in multiple GPUs using accelerate#22532
Conversation
|
Could you elaborate why is such a patch needed and what is the goal of your PR? Cause all of this seems very hacky. |
|
These hacks were needed because Moreover, modifying transformers/src/transformers/training_args.py Line 1801 in 9419f14 That is why I preferred to introduce a new argument to avoid modifying what is already in place and modify directly what is needed to be edited, without having to modify the model's internals (forcing |
|
The documentation is not available anymore as the PR was closed or merged. |
|
Or you could just analyze the device map of the model and determine if there are several GPUs used. It would be cleaner and not require the user to learn the 97th training argument. |
|
Ahh yes good point! |
Trainer] Add is_model_parallel argumentTrainer] Force is_model_parallel when model is loaded in multiple GPUs using accelerate
|
|
||
| if ( | ||
| getattr(model, "hf_device_map", None) is not None | ||
| and len(set(model.hf_device_map.values())) > 1 |
There was a problem hiding this comment.
Maybe actually check the number of GPUs, cause this could be one GPU and CPU here.
There was a problem hiding this comment.
Multi-device placement should only be on GPUs for naive pipelining to work, right? Offloading to CPU/disk won't work, isn't it the case?
There was a problem hiding this comment.
I think offloading to CPU/disk won't work yes, I am also unsure if CPU/disk offload training works out of the box with accelerate (without DeepSpeed)
…e GPUs using `accelerate` (huggingface#22532) * add `is_model_parallel` arg on Trainer * add warning * adapt from suggestions * revert t5 changes * remove commas * adapt from suggestions
…e GPUs using `accelerate` (huggingface#22532) * add `is_model_parallel` arg on Trainer * add warning * adapt from suggestions * revert t5 changes * remove commas * adapt from suggestions
What does this PR do?
When using the Trainer on a multi-GPU environment, users currently apply a patch that leads to some bugs.
Before running a training they need to call:
Which can lead to unexpected bugs on some models, such as T5, that has the
parallelizeAPI that is still in place, thus when forcingmodel_parallelto beTrue, calls that API, which is deprecated and should not be maintained.Script to reproduce:
cc @sgugger
Related: huggingface/peft#205