diff --git a/src/transformers/models/dinat/configuration_dinat.py b/src/transformers/models/dinat/configuration_dinat.py
index d51fc660d1c9..28a5d9440411 100644
--- a/src/transformers/models/dinat/configuration_dinat.py
+++ b/src/transformers/models/dinat/configuration_dinat.py
@@ -70,6 +70,8 @@ class DinatConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
+            The initial value for the layer scale. Disabled if <=0.
 
     Example:
 
@@ -110,6 +112,7 @@ def __init__(
         patch_norm=True,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
+        layer_scale_init_value=0.0,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -134,3 +137,4 @@ def __init__(
         # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+        self.layer_scale_init_value = layer_scale_init_value
diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py
index ba78a00b1998..4648da0ce26c 100644
--- a/src/transformers/models/dinat/modeling_dinat.py
+++ b/src/transformers/models/dinat/modeling_dinat.py
@@ -462,6 +462,11 @@ def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
         self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.intermediate = DinatIntermediate(config, dim)
         self.output = DinatOutput(config, dim)
+        self.layer_scale_parameters = (
+            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
+            if config.layer_scale_init_value > 0
+            else None
+        )
 
     def maybe_pad(self, hidden_states, height, width):
         window_size = self.window_size
@@ -496,11 +501,18 @@ def forward(
         if was_padded:
             attention_output = attention_output[:, :height, :width, :].contiguous()
 
+        if self.layer_scale_parameters is not None:
+            attention_output = self.layer_scale_parameters[0] * attention_output
+
         hidden_states = shortcut + self.drop_path(attention_output)
 
         layer_output = self.layernorm_after(hidden_states)
-        layer_output = self.intermediate(layer_output)
-        layer_output = hidden_states + self.output(layer_output)
+        layer_output = self.output(self.intermediate(layer_output))
+
+        if self.layer_scale_parameters is not None:
+            layer_output = self.layer_scale_parameters[1] * layer_output
+
+        layer_output = hidden_states + self.drop_path(layer_output)
 
         layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
         return layer_outputs
diff --git a/src/transformers/models/nat/__init__.py b/src/transformers/models/nat/__init__.py
index c5cc25af9f8a..f98758fcffbc 100644
--- a/src/transformers/models/nat/__init__.py
+++ b/src/transformers/models/nat/__init__.py
@@ -12,7 +12,6 @@
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
-# distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
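For reference, below is a minimal, self-contained sketch of the layer-scale mechanism the hunks above introduce: a learnable `(2, dim)` parameter filled with `layer_scale_init_value`, whose first row scales the attention branch and whose second row scales the MLP branch before each residual addition. The class name `LayerScaleSketch` and the `nn.Identity` stand-ins for the real attention/MLP sub-modules are illustrative assumptions, not part of the PR (the PR also applies drop path to both branches, omitted here for brevity).

```python
import torch
from torch import nn


class LayerScaleSketch(nn.Module):
    """Illustration only: per-channel scaling of both residual branches,
    mirroring how layer_scale_parameters[0] / [1] are used in the diff."""

    def __init__(self, dim: int, layer_scale_init_value: float = 0.0):
        super().__init__()
        # One row of scales per residual branch (attention, MLP), as in the PR:
        # a (2, dim) parameter filled with the configured init value, or disabled if <= 0.
        self.layer_scale_parameters = (
            nn.Parameter(layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        # Stand-ins for the real attention / MLP sub-modules (assumption, not the PR's code).
        self.attention = nn.Identity()
        self.mlp = nn.Identity()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention branch, scaled per channel before the residual addition.
        attention_output = self.attention(hidden_states)
        if self.layer_scale_parameters is not None:
            attention_output = self.layer_scale_parameters[0] * attention_output
        hidden_states = hidden_states + attention_output

        # MLP branch, scaled by the second row of the parameter.
        layer_output = self.mlp(hidden_states)
        if self.layer_scale_parameters is not None:
            layer_output = self.layer_scale_parameters[1] * layer_output
        return hidden_states + layer_output
```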
diff --git a/src/transformers/models/nat/configuration_nat.py b/src/transformers/models/nat/configuration_nat.py
index cde88bd160d9..30fb682ff00a 100644
--- a/src/transformers/models/nat/configuration_nat.py
+++ b/src/transformers/models/nat/configuration_nat.py
@@ -68,6 +68,8 @@ class NatConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
+            The initial value for the layer scale. Disabled if <=0.
 
     Example:
 
@@ -107,6 +109,7 @@ def __init__(
         patch_norm=True,
         initializer_range=0.02,
         layer_norm_eps=1e-5,
+        layer_scale_init_value=0.0,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -130,3 +133,4 @@ def __init__(
         # we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+        self.layer_scale_init_value = layer_scale_init_value
diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py
index f5216f55e14b..3d57717c525b 100644
--- a/src/transformers/models/nat/modeling_nat.py
+++ b/src/transformers/models/nat/modeling_nat.py
@@ -445,6 +445,11 @@ def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
         self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.intermediate = NatIntermediate(config, dim)
         self.output = NatOutput(config, dim)
+        self.layer_scale_parameters = (
+            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
+            if config.layer_scale_init_value > 0
+            else None
+        )
 
     def maybe_pad(self, hidden_states, height, width):
         window_size = self.kernel_size
@@ -479,11 +484,18 @@ def forward(
         if was_padded:
             attention_output = attention_output[:, :height, :width, :].contiguous()
 
+        if self.layer_scale_parameters is not None:
+            attention_output = self.layer_scale_parameters[0] * attention_output
+
         hidden_states = shortcut + self.drop_path(attention_output)
 
         layer_output = self.layernorm_after(hidden_states)
-        layer_output = self.intermediate(layer_output)
-        layer_output = hidden_states + self.output(layer_output)
+        layer_output = self.output(self.intermediate(layer_output))
+
+        if self.layer_scale_parameters is not None:
+            layer_output = self.layer_scale_parameters[1] * layer_output
+
+        layer_output = hidden_states + self.drop_path(layer_output)
 
         layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
         return layer_outputs
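As a usage sketch (not part of the diff): layer scale stays disabled under the default `layer_scale_init_value=0.0`, and passing a small positive value, e.g. `1e-6` as is commonly used for layer scale, enables the extra `(2, dim)` parameters in every layer. Instantiating the model is assumed to require the `natten` backend that Nat/Dinat depend on.

```python
from transformers import NatConfig, NatModel

# Default config leaves layer scale off (layer_scale_parameters is None in every layer).
config = NatConfig(layer_scale_init_value=1e-6)  # small positive value enables layer scale
model = NatModel(config)

# Count the newly added scale parameters (2 * dim per layer).
num_scale_params = sum(
    p.numel() for name, p in model.named_parameters() if "layer_scale_parameters" in name
)
print(num_scale_params)
```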