From 5922e853f3a6219ea15c7195b91c9d53533851b4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 3 Apr 2023 14:59:59 +0000 Subject: [PATCH 1/3] enable PP for T5 --- src/transformers/models/t5/modeling_t5.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index c056ef73cc1f..a3d1f8934782 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1746,7 +1746,11 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + lm_logits = lm_logits.view(-1, lm_logits.size(-1)) + labels = labels.view(-1) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits, labels) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: From a82c0b6eb5b8e30e7263cd6043da48a69203b76d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 3 Apr 2023 15:02:54 +0000 Subject: [PATCH 2/3] make fixup --- src/transformers/models/mt5/modeling_mt5.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 43beb69f10f2..bd2bcad019a3 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1778,7 +1778,11 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + lm_logits = lm_logits.view(-1, lm_logits.size(-1)) + labels = labels.view(-1) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits, labels) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: From 33ab0bb2a443d02ecd8a122740e6562578cca6ee Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 3 Apr 2023 15:29:02 +0000 Subject: [PATCH 3/3] fix failing tests --- src/transformers/models/mt5/modeling_mt5.py | 4 +--- src/transformers/models/t5/modeling_t5.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index bd2bcad019a3..7e7a8903c011 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1778,11 +1778,9 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) - lm_logits = lm_logits.view(-1, lm_logits.size(-1)) - labels = labels.view(-1) # move labels to correct device to enable PP labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits, labels) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a3d1f8934782..dfb9907a3cf5 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1746,11 +1746,9 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) - lm_logits = lm_logits.view(-1, lm_logits.size(-1)) - labels = labels.view(-1) # move labels to correct device to enable PP labels = labels.to(lm_logits.device) - loss = loss_fct(lm_logits, labels) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: