Skip to content

[T5] Enable naive Pipeline Parallelism training for T5#22535

Merged
younesbelkada merged 4 commits into
huggingface:mainfrom
younesbelkada:fix-t5-pp
Apr 3, 2023
Merged

[T5] Enable naive Pipeline Parallelism training for T5#22535
younesbelkada merged 4 commits into
huggingface:mainfrom
younesbelkada:fix-t5-pp

Conversation

@younesbelkada

@younesbelkada younesbelkada commented Apr 3, 2023

Copy link
Copy Markdown
Contributor

What does this PR do?

Similarly as #22329 this PR enables training T5 models in a "Naive Pipeline Parallelism" setup. What is termed as "Naive Pipeline Parallelism" is simply to spread the model across multiple GPUs and run naively the forward/backward pass by communicating the activations and gradients between each GPU.

Without this fix, users will encounter device mismatch issues when training this model that has been loaded across multiple GPUs. Hence, the fix is to manually set the device of the labels to the same device as lm_logits.

A simple snippet to reproduce the behaviour below (this needs to be run on a multi-gpu env):

import torch
from transformers import AutoModelForSeq2SeqLM

model_id = "google/flan-t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="balanced")
print(set(model.hf_device_map.values())) # >>> {0, 1}

dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]])

loss = model(input_ids=dummy_input, labels=dummy_input).loss

Error trace:

│   1746 │   │   loss = None                                                                       │
│   1747 │   │   if labels is not None:                                                            │
│   1748 │   │   │   loss_fct = CrossEntropyLoss(ignore_index=-100)                                │
│ ❱ 1749 │   │   │   loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))      │
│   1750 │   │   │   # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc  │
│   1751 │   │                                                                                     │
│   1752 │   │   if not return_dict:                                                               │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/module │
│ s/module.py:1501 in _call_impl                                                                   │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/module │
│ s/loss.py:1174 in forward                                                                        │
│                                                                                                  │
│   1171 │   │   self.label_smoothing = label_smoothing                                            │
│   1172 │                                                                                         │
│   1173 │   def forward(self, input: Tensor, target: Tensor) -> Tensor:                           │
│ ❱ 1174 │   │   return F.cross_entropy(input, target, weight=self.weight,                         │
│   1175 │   │   │   │   │   │   │      ignore_index=self.ignore_index, reduction=self.reduction,  │
│   1176 │   │   │   │   │   │   │      label_smoothing=self.label_smoothing)                      │
│   1177                                                                                           │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/functi │
│ onal.py:3029 in cross_entropy                                                                    │
│                                                                                                  │
│   3026 │   │   )                                                                                 │
│   3027 │   if size_average is not None or reduce is not None:                                    │
│   3028 │   │   reduction = _Reduction.legacy_get_string(size_average, reduce)                    │
│ ❱ 3029 │   return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(re  │
│   3030                                                                                           │
│   3031                                                                                           │
│   3032 def binary_cross_entropy(                                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target 
in method wrapper_CUDA_nll_loss_forward)

cc @sgugger

Related issues:

huggingface/peft#242
huggingface/peft#205

@younesbelkada younesbelkada marked this pull request as ready for review April 3, 2023 15:17
@younesbelkada younesbelkada requested a review from sgugger April 3, 2023 15:17
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 3, 2023

Copy link
Copy Markdown

The documentation is not available anymore as the PR was closed or merged.

@sgugger sgugger left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!

@younesbelkada younesbelkada merged commit d7a4f5b into huggingface:main Apr 3, 2023
@younesbelkada younesbelkada deleted the fix-t5-pp branch April 3, 2023 15:55
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
…#22535)

* enable PP for T5

* make fixup

* fix failing tests
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…#22535)

* enable PP for T5

* make fixup

* fix failing tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants