Merged
23 changes: 23 additions & 0 deletions docs/source/ar/trainer.md
Original file line number Diff line number Diff line change
@@ -673,6 +673,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>
The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to run your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and loaded automatically when you run `accelerate_launch`.
2 changes: 1 addition & 1 deletion docs/source/en/llm_tutorial_optimization.md
@@ -55,7 +55,7 @@ To give some examples of how much VRAM it roughly takes to load a model in bfloa

As of writing this document, the largest GPU chips on the market are the A100 and H100, offering 80GB of VRAM. Most of the models listed before require more than 80GB just to be loaded and therefore necessarily require [tensor parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) and/or [pipeline parallelism](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).

🤗 Transformers does not support tensor parallelism out of the box as it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
🤗 Transformers now supports tensor parallelism for models that define a `base_tp_plan` in their respective config classes. Learn more about tensor parallelism [here](perf_train_gpu_many#tensor-parallelism). Furthermore, if you're interested in writing models in a tensor-parallelism-friendly way, feel free to have a look at [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling).
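The sharding idea behind tensor parallelism can be illustrated with a small, self-contained sketch. This is plain Python with hypothetical helpers, not the actual Transformers/PyTorch implementation: a weight matrix is split column-wise across "devices", each device computes a slice of the output, and concatenating the slices reproduces the unsharded result — so each device only needs to hold `1/tp_size` of the weights.

```python
# Illustrative column-parallel sharding (hypothetical helpers, not the
# actual Transformers/PyTorch implementation).

def matmul(x, w):
    # x: batch of input rows; w: weight matrix as a list of rows (in_dim x out_dim).
    out_dim = len(w[0])
    return [[sum(xi[k] * w[k][j] for k in range(len(w))) for j in range(out_dim)]
            for xi in x]

def shard_columns(w, tp_size):
    # Split a (in_dim x out_dim) weight matrix into tp_size column blocks,
    # one per "device".
    out_dim = len(w[0])
    block = out_dim // tp_size
    return [[row[r * block:(r + 1) * block] for row in w] for r in range(tp_size)]

def column_parallel_matmul(x, w, tp_size):
    # Each "device" multiplies the full input by its weight shard, producing
    # a slice of the output; concatenation recovers the full result.
    shards = shard_columns(w, tp_size)
    partials = [matmul(x, shard) for shard in shards]
    return [sum((p[i] for p in partials), []) for i in range(len(x))]

x = [[1.0, 2.0]]
w = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
assert column_parallel_matmul(x, w, tp_size=2) == matmul(x, w)
```

In the real implementation the shards live on different GPUs and communication collectives gather the partial outputs; the memory saving per device is the same `1/tp_size` factor shown here.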

Naive pipeline parallelism is supported out of the box. For this, simply load the model with `device_map="auto"`, which will automatically place the different layers on the available GPUs as explained [here](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference).
Note, however, that while very effective, this naive pipeline parallelism does not tackle the issue of GPU idling. For this, more advanced pipeline parallelism is required, as explained [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism).
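As a rough mental model of the placement step, layers are assigned in order to the first device with free memory, spilling onto the next device when one fills up. The sketch below uses hypothetical sizes and is not Accelerate's actual placement algorithm:

```python
# Minimal sketch of naive ("vertical") layer placement in the spirit of
# device_map="auto" (hypothetical sizes; not Accelerate's real algorithm).

def place_layers(layer_sizes, device_capacities):
    """Greedily assign layers, in order, to the first device with room."""
    placement, device, used = {}, 0, 0
    for i, size in enumerate(layer_sizes):
        if used + size > device_capacities[device]:
            device += 1          # spill onto the next GPU
            used = 0
            if device >= len(device_capacities):
                raise MemoryError("model does not fit on available devices")
        placement[f"layer.{i}"] = device
        used += size
    return placement

# Four layers of 6 GB each onto two 16 GB GPUs:
print(place_layers([6, 6, 6, 6], [16, 16]))
# {'layer.0': 0, 'layer.1': 0, 'layer.2': 1, 'layer.3': 1}
```

Because the forward pass runs the layers in order, only one device is active at any time under this scheme — exactly the GPU-idling problem that more advanced pipeline parallelism addresses.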
7 changes: 4 additions & 3 deletions docs/source/en/perf_train_gpu_many.md
@@ -450,12 +450,13 @@ Implementations:
- [parallelformers](https://github.com/tunib-ai/parallelformers) (only inference at the moment)
- [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS.
- [OSLO](https://github.com/tunib-ai/oslo) has a tensor parallelism implementation based on Transformers.
- [`transformers` integration](main_classes/trainer) - tensor parallelism is available through the `tp_size` attribute for models that define a `base_tp_plan`. See [example usage](perf_infer_gpu_multi).
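Internally, the `tp_size` value together with the total number of processes determines how ranks are grouped into tensor-parallel groups. A hedged sketch of one way such a 2-D (data-parallel × tensor-parallel) grouping could be derived — the real mesh is built by Accelerate/PyTorch, not this hypothetical helper:

```python
# Hypothetical sketch of deriving a (data-parallel x tensor-parallel) rank
# grouping from the world size; not Accelerate's actual mesh construction
# (that uses torch.distributed device meshes internally).

def build_mesh(num_processes, tp_size):
    if num_processes % tp_size != 0:
        raise ValueError("num_processes must be divisible by tp_size")
    dp_size = num_processes // tp_size
    # Consecutive ranks [r*tp_size, (r+1)*tp_size) form one TP group.
    return [[dp * tp_size + tp for tp in range(tp_size)] for dp in range(dp_size)]

# 8 GPUs with tp_size=4 -> 2 data-parallel replicas of 4-way TP groups:
print(build_mesh(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```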

SageMaker combines TP with DP for more efficient processing.

🤗 Transformers status:
- core: not yet implemented in the core
- but if you want inference [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core you can use theirs. And hopefully training mode will be supported too.
- core: uses PyTorch 2 APIs to support tensor parallelism for models that define a `base_tp_plan` in their respective config classes.
- Alternatively, you can try [parallelformers](https://github.com/tunib-ai/parallelformers), which provides this support for most of our models. Training with TP is also supported natively in Transformers.
- Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/)

🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm).
@@ -535,7 +536,7 @@ Important papers:
- [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](
https://arxiv.org/abs/2201.11990)

🤗 Transformers status: not yet implemented, since we have no PP and TP.
🤗 Transformers status: not yet implemented, since we have no PP.

## FlexFlow

23 changes: 23 additions & 0 deletions docs/source/en/trainer.md
@@ -799,6 +799,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

24 changes: 24 additions & 0 deletions docs/source/es/trainer.md
@@ -361,6 +361,30 @@ use_cpu: false

```

</hfoption>

<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

23 changes: 23 additions & 0 deletions docs/source/ko/trainer.md
@@ -548,6 +548,29 @@ tpu_use_sudo: false
use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoption>
</hfoptions>

12 changes: 11 additions & 1 deletion src/transformers/trainer.py
@@ -241,6 +241,8 @@
)

DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("1.3.0"):
from accelerate.utils import TorchTensorParallelPlugin
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler

@@ -5094,6 +5096,14 @@ def create_accelerator_and_postprocess(self):
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)
# tp is initialized at Accelerator init phase so
# args should be prepared here
if self.args.tp_size > 1:
self.is_tp_enabled = True
if version.parse(accelerate_version) > version.parse("1.3.0"):
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
Collaborator:
I am not familiar with this API, so we need some documentation about what this uses under the hood!
Also, could we check whether the model supports TP? Or is it not even needed?

Contributor Author:
> Also we could check if model supports TP? or is it not even needed?

Hi @ArthurZucker, we check this in Accelerate here: https://github.com/huggingface/accelerate/blob/526925b48c07d997cdd9bf5911f659ca45778915/src/accelerate/accelerator.py#L1511

Contributor Author:
@ArthurZucker

> I am not familiar with this API, so we need some documentation about what this uses under the hood!

Do you want me to add that as a comment above this piece of code, or could you point me to a place in the HF docs where you would like me to add it? Thank you.

Collaborator:
It can be added to the documentation for the TP feature!

Collaborator:
But interesting, I did not know Accelerate already supported it!

Contributor Author:
@ArthurZucker We added the support to Accelerate in huggingface/accelerate#3173

else:
raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")

# create accelerator object
self.accelerator = Accelerator(**args)
@@ -5108,7 +5118,7 @@ def create_accelerator_and_postprocess(self):
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
25 changes: 24 additions & 1 deletion src/transformers/training_args.py
@@ -569,7 +569,10 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.

tp_size (`int`, *optional*):
Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available for models that define a `base_tp_plan`
in their respective config classes.
Set a value greater than 1 to activate TP. The same value is used to prepare the device mesh internally. Requires accelerate > 1.3.0.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1250,6 +1253,18 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable PyTorch tensor parallelism. "
"Tensor parallelism support is only available for models that define a `base_tp_plan` in their respective config classes. "
"Set a value greater than 1 to activate TP. "
"The same value is used to prepare the device mesh internally. "
"Requires accelerate > 1.3.0."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
@@ -1975,6 +1990,14 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

if self.tp_size > 1:
if not is_accelerate_available("1.3.1"):
raise NotImplementedError(
"Tensor parallelism with PyTorch requires `accelerate` >= 1.3.1. "
"Please update your `accelerate` version."
)
os.environ["ACCELERATE_USE_TP"] = "true"
os.environ["TP_SIZE"] = str(self.tp_size)
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"