6 changes: 3 additions & 3 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -269,11 +269,11 @@ def __init__(self, config: BioGptConfig):
self.activation_fn = ACT2FN[config.hidden_act]
self.activation_dropout = config.activation_dropout

- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
@@ -452,7 +452,7 @@ def __init__(self, config: BioGptConfig):
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)

self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])
- self.layer_norm = nn.LayerNorm(self.embed_dim)
+ self.layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

self.gradient_checkpointing = False
# Initialize weights and apply final processing
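Context for the changes above (not part of the diff): PyTorch's nn.LayerNorm defaults to eps=1e-5, so before this PR these layers silently ignored whatever layer_norm_eps the checkpoint's config declared. A minimal, self-contained sketch of how much eps alone can move the output — illustrative values only, not BioGPT's actual defaults:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_states = torch.randn(2, 8, 1024)

ln_default = nn.LayerNorm(1024)                  # what the old code built: eps falls back to 1e-5
ln_from_config = nn.LayerNorm(1024, eps=1e-12)   # what the new code builds if the config says 1e-12

# Share weights so eps is the only difference.
ln_from_config.load_state_dict(ln_default.state_dict())

max_diff = (ln_default(hidden_states) - ln_from_config(hidden_states)).abs().max().item()
print(f"max abs output difference due to eps alone: {max_diff:.2e}")
```

The drift is small per layer, but it compounds across a deep stack and breaks exact numerical parity with checkpoints trained under the original eps.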
10 changes: 5 additions & 5 deletions src/transformers/models/clip/modeling_clip.py
@@ -352,9 +352,9 @@ def __init__(self, config: CLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
@@ -676,7 +676,7 @@ def __init__(self, config: CLIPTextConfig):
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
- self.final_layer_norm = nn.LayerNorm(embed_dim)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
@@ -826,9 +826,9 @@ def __init__(self, config: CLIPVisionConfig):
embed_dim = config.hidden_size

self.embeddings = CLIPVisionEmbeddings(config)
- self.pre_layrnorm = nn.LayerNorm(embed_dim)
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config)
- self.post_layernorm = nn.LayerNorm(embed_dim)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
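One way to guard this wiring against regressions is to build a tiny model with a non-default eps and assert every LayerNorm picked it up. A sketch of such a check (arbitrary small sizes; assumes the post-PR behaviour of CLIPTextModel):

```python
import torch.nn as nn
from transformers import CLIPTextConfig, CLIPTextModel

# Deliberately unusual eps so a silent fallback to the 1e-5 default is easy to spot.
config = CLIPTextConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    layer_norm_eps=3e-6,
)
model = CLIPTextModel(config)

for name, module in model.named_modules():
    if isinstance(module, nn.LayerNorm):
        assert module.eps == config.layer_norm_eps, f"{name} ignores config.layer_norm_eps"
print("all LayerNorm modules use config.layer_norm_eps")
```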
14 changes: 7 additions & 7 deletions src/transformers/models/clipseg/modeling_clipseg.py
@@ -375,9 +375,9 @@ def __init__(self, config: CLIPSegConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPSegAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPSegMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
@@ -687,7 +687,7 @@ def __init__(self, config: CLIPSegTextConfig):
embed_dim = config.hidden_size
self.embeddings = CLIPSegTextEmbeddings(config)
self.encoder = CLIPSegEncoder(config)
- self.final_layer_norm = nn.LayerNorm(embed_dim)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
@@ -833,9 +833,9 @@ def __init__(self, config: CLIPSegVisionConfig):
embed_dim = config.hidden_size

self.embeddings = CLIPSegVisionEmbeddings(config)
- self.pre_layrnorm = nn.LayerNorm(embed_dim)
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPSegEncoder(config)
- self.post_layernorm = nn.LayerNorm(embed_dim)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
@@ -1174,9 +1174,9 @@ def __init__(self, config: CLIPSegConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPSegAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPSegMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
13 changes: 7 additions & 6 deletions src/transformers/models/glpn/modeling_glpn.py
@@ -93,7 +93,7 @@ def extra_repr(self) -> str:
class GLPNOverlapPatchEmbeddings(nn.Module):
"""Construct the overlapping patch embeddings."""

- def __init__(self, patch_size, stride, num_channels, hidden_size):
+ def __init__(self, config, patch_size, stride, num_channels, hidden_size):
PR author (collaborator) comment: Need this new argument so we can use eps=config.layer_norm_eps. As this is an internal class, it should be fine.

super().__init__()
self.proj = nn.Conv2d(
num_channels,
@@ -103,7 +103,7 @@ def __init__(self, patch_size, stride, num_channels, hidden_size):
padding=patch_size // 2,
)

- self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

def forward(self, pixel_values):
embeddings = self.proj(pixel_values)
@@ -145,7 +145,7 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_
self.sr = nn.Conv2d(
hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
)
- self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

def transpose_for_scores(self, hidden_states):
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -294,15 +294,15 @@ class GLPNLayer(nn.Module):

def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
super().__init__()
- self.layer_norm_1 = nn.LayerNorm(hidden_size)
+ self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attention = GLPNAttention(
config,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
sequence_reduction_ratio=sequence_reduction_ratio,
)
self.drop_path = GLPNDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.layer_norm_2 = nn.LayerNorm(hidden_size)
+ self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
mlp_hidden_size = int(hidden_size * mlp_ratio)
self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

@@ -345,6 +345,7 @@ def __init__(self, config):
for i in range(config.num_encoder_blocks):
embeddings.append(
GLPNOverlapPatchEmbeddings(
+ config=config,
patch_size=config.patch_sizes[i],
stride=config.strides[i],
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
@@ -378,7 +379,7 @@ def __init__(self, config):

# Layer norms
self.layer_norm = nn.ModuleList(
- [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
+ [nn.LayerNorm(config.hidden_sizes[i], eps=config.layer_norm_eps) for i in range(config.num_encoder_blocks)]
)

def forward(
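GLPN (and Segformer below) differ slightly from the other models in this PR: GLPNOverlapPatchEmbeddings previously never received the config, so the whole config object is now threaded into the internal module just so it can read layer_norm_eps, as the review comment notes. A stripped-down sketch of that pattern with simplified names, not the actual transformers classes:

```python
from types import SimpleNamespace

import torch.nn as nn


class OverlapPatchEmbeddings(nn.Module):
    """Internal module that now takes `config` only to read `config.layer_norm_eps`."""

    def __init__(self, config, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)


class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_embeddings = nn.ModuleList(
            [
                OverlapPatchEmbeddings(
                    config=config,  # forwarded so eps no longer falls back to the nn.LayerNorm default
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
                for i in range(config.num_encoder_blocks)
            ]
        )


# Toy stand-in for the real config, just to show the eps propagating.
cfg = SimpleNamespace(
    layer_norm_eps=1e-6, patch_sizes=[7, 3], strides=[4, 2],
    num_channels=3, hidden_sizes=[32, 64], num_encoder_blocks=2,
)
encoder = Encoder(cfg)
print(encoder.patch_embeddings[1].layer_norm.eps)  # 1e-06
```

Passing the full config rather than a bare eps float keeps the signature consistent with sibling modules that already take config; since the class is internal, the extra argument is not a public API break.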
8 changes: 4 additions & 4 deletions src/transformers/models/groupvit/modeling_groupvit.py
@@ -714,9 +714,9 @@ def __init__(self, config: GroupViTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = GroupViTAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = GroupViTMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
@@ -1076,7 +1076,7 @@ def __init__(self, config: GroupViTTextConfig):
embed_dim = config.hidden_size
self.embeddings = GroupViTTextEmbeddings(config)
self.encoder = GroupViTTextEncoder(config)
- self.final_layer_norm = nn.LayerNorm(embed_dim)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTTextConfig)
@@ -1219,7 +1219,7 @@ def __init__(self, config: GroupViTVisionConfig):

self.embeddings = GroupViTVisionEmbeddings(config)
self.encoder = GroupViTVisionEncoder(config)
- self.layernorm = nn.LayerNorm(embed_dim)
+ self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
14 changes: 7 additions & 7 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -281,7 +281,7 @@ def __init__(self, config):

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
@@ -375,7 +375,7 @@ class LxmertAttentionOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
@@ -437,7 +437,7 @@ class LxmertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
@@ -563,11 +563,11 @@ def __init__(self, config):

# Object feature encoding
self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
- self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

# Box position encoding
self.box_fc = nn.Linear(pos_dim, config.hidden_size)
- self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

self.dropout = nn.Dropout(config.hidden_dropout_prob)

@@ -684,7 +684,7 @@ def __init__(self, config):
super(LxmertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act]
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -721,7 +721,7 @@ def __init__(self, config, num_labels):
self.logit_fc = nn.Sequential(
nn.Linear(hid_dim, hid_dim * 2),
GeLU(),
- nn.LayerNorm(hid_dim * 2, eps=1e-12),
+ nn.LayerNorm(hid_dim * 2, eps=config.layer_norm_eps),
nn.Linear(hid_dim * 2, num_labels),
)

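LXMERT is the one model in this PR whose layer norms already carried an explicit eps; it was just hardcoded to 1e-12. The refactor is therefore a pure no-op only if LxmertConfig's layer_norm_eps defaults to that same 1e-12 (the usual BERT-style default, but worth confirming against the config class). A tiny parity sketch in plain PyTorch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 768)

config_eps = 1e-12  # assumed LxmertConfig.layer_norm_eps default; verify against configuration_lxmert.py
old_ln = nn.LayerNorm(768, eps=1e-12)       # previous hardcoded value
new_ln = nn.LayerNorm(768, eps=config_eps)  # value now read from the config
new_ln.load_state_dict(old_ln.state_dict())

# Identical eps and identical weights must give bitwise-identical outputs.
assert torch.equal(old_ln(x), new_ln(x)), "config default diverges from the old hardcoded 1e-12"
print("refactor is numerically a no-op for eps =", config_eps)
```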
12 changes: 6 additions & 6 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -476,9 +476,9 @@ def __init__(self, config: OwlViTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OwlViTAttention(config)
- self.layer_norm1 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = OwlViTMLP(config)
- self.layer_norm2 = nn.LayerNorm(self.embed_dim)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
self,
@@ -790,7 +790,7 @@ def __init__(self, config: OwlViTTextConfig):
embed_dim = config.hidden_size
self.embeddings = OwlViTTextEmbeddings(config)
self.encoder = OwlViTEncoder(config)
- self.final_layer_norm = nn.LayerNorm(embed_dim)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
@@ -922,9 +922,9 @@ def __init__(self, config: OwlViTVisionConfig):
self.config = config

self.embeddings = OwlViTVisionEmbeddings(config)
- self.pre_layernorm = nn.LayerNorm(config.hidden_size)
+ self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = OwlViTEncoder(config)
- self.post_layernorm = nn.LayerNorm(config.hidden_size)
+ self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

@add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig)
@@ -1318,7 +1318,7 @@ def __init__(self, config: OwlViTConfig):
self.class_head = OwlViTClassPredictionHead(config)
self.box_head = OwlViTBoxPredictionHead(config)

- self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
self.sigmoid = nn.Sigmoid()

def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
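The last OwlViT hunk sits in the object-detection head, where the top-level OwlViTConfig is a composite wrapping separate text and vision configs; the head's LayerNorm operates on vision features, hence the nested config.vision_config.layer_norm_eps lookup rather than a top-level attribute. A short sketch of where that value lives (standard composite-config attributes, default values):

```python
from transformers import OwlViTConfig

config = OwlViTConfig()  # bundles an OwlViTTextConfig and an OwlViTVisionConfig
print(config.vision_config.hidden_size)     # width the detection head's LayerNorm is built over
print(config.vision_config.layer_norm_eps)  # eps now applied to that LayerNorm
```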
13 changes: 7 additions & 6 deletions src/transformers/models/segformer/modeling_segformer.py
@@ -124,7 +124,7 @@ def extra_repr(self) -> str:
class SegformerOverlapPatchEmbeddings(nn.Module):
"""Construct the overlapping patch embeddings."""

- def __init__(self, patch_size, stride, num_channels, hidden_size):
+ def __init__(self, config, patch_size, stride, num_channels, hidden_size):
PR author (collaborator) comment: Need this new argument so we can use eps=config.layer_norm_eps. As this is an internal class, it should be fine.

super().__init__()
self.proj = nn.Conv2d(
num_channels,
@@ -134,7 +134,7 @@ def __init__(self, patch_size, stride, num_channels, hidden_size):
padding=patch_size // 2,
)

- self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

def forward(self, pixel_values):
embeddings = self.proj(pixel_values)
@@ -175,7 +175,7 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_
self.sr = nn.Conv2d(
hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
)
- self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

def transpose_for_scores(self, hidden_states):
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -319,15 +319,15 @@ class SegformerLayer(nn.Module):

def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
super().__init__()
- self.layer_norm_1 = nn.LayerNorm(hidden_size)
+ self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attention = SegformerAttention(
config,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
sequence_reduction_ratio=sequence_reduction_ratio,
)
self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.layer_norm_2 = nn.LayerNorm(hidden_size)
+ self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
mlp_hidden_size = int(hidden_size * mlp_ratio)
self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

@@ -370,6 +370,7 @@ def __init__(self, config):
for i in range(config.num_encoder_blocks):
embeddings.append(
SegformerOverlapPatchEmbeddings(
+ config=config,
patch_size=config.patch_sizes[i],
stride=config.strides[i],
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
@@ -403,7 +404,7 @@ def __init__(self, config):

# Layer norms
self.layer_norm = nn.ModuleList(
- [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
+ [nn.LayerNorm(config.hidden_sizes[i], eps=config.layer_norm_eps) for i in range(config.num_encoder_blocks)]
)

def forward(