Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/models/dinat/configuration_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class DinatConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.

Example:

Expand Down Expand Up @@ -110,6 +112,7 @@ def __init__(
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -134,3 +137,4 @@ def __init__(
# we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
16 changes: 14 additions & 2 deletions src/transformers/models/dinat/modeling_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = DinatIntermediate(config, dim)
self.output = DinatOutput(config, dim)
self.layer_scale_parameters = (
nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
if config.layer_scale_init_value > 0
else None
)

def maybe_pad(self, hidden_states, height, width):
window_size = self.window_size
Expand Down Expand Up @@ -496,11 +501,18 @@ def forward(
if was_padded:
attention_output = attention_output[:, :height, :width, :].contiguous()

if self.layer_scale_parameters is not None:
attention_output = self.layer_scale_parameters[0] * attention_output

hidden_states = shortcut + self.drop_path(attention_output)

layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_output = self.output(self.intermediate(layer_output))

if self.layer_scale_parameters is not None:
layer_output = self.layer_scale_parameters[1] * layer_output

layer_output = hidden_states + self.drop_path(layer_output)

layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/nat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/nat/configuration_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class NatConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0.

Example:

Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(
patch_norm=True,
initializer_range=0.02,
layer_norm_eps=1e-5,
layer_scale_init_value=0.0,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -130,3 +133,4 @@ def __init__(
# we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
16 changes: 14 additions & 2 deletions src/transformers/models/nat/modeling_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = NatIntermediate(config, dim)
self.output = NatOutput(config, dim)
self.layer_scale_parameters = (
nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
if config.layer_scale_init_value > 0
else None
)

def maybe_pad(self, hidden_states, height, width):
window_size = self.kernel_size
Expand Down Expand Up @@ -479,11 +484,18 @@ def forward(
if was_padded:
attention_output = attention_output[:, :height, :width, :].contiguous()

if self.layer_scale_parameters is not None:
attention_output = self.layer_scale_parameters[0] * attention_output

hidden_states = shortcut + self.drop_path(attention_output)

layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
layer_output = self.output(self.intermediate(layer_output))

if self.layer_scale_parameters is not None:
layer_output = self.layer_scale_parameters[1] * layer_output

layer_output = hidden_states + self.drop_path(layer_output)

layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
Expand Down