Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def postprocess_text(preds, labels):

# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/longt5/modeling_longt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
Expand Down Expand Up @@ -247,6 +248,8 @@ def forward(self, hidden_states):
return self.weight * hidden_states


ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be after the try/except block below, which can change LongT5LayerNorm.


try:
from apex.normalization import FusedRMSNorm

Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
Expand Down Expand Up @@ -262,6 +262,8 @@ def forward(self, hidden_states):
return self.weight * hidden_states


ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be after the try except block below, which may change T5LayerNorm.


try:
from apex.normalization import FusedRMSNorm

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from .utils import logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
Expand Down Expand Up @@ -967,7 +968,7 @@ def create_optimizer(self):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
Expand Down