diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index d0224e3caa5b..f422b17b204f 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -575,7 +575,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.swin.modeling_swin.SwinLayer with SwinDropPath->ClapDropPath, Swin->ClapAudio class ClapAudioLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size @@ -583,7 +583,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size) - self.drop_path = ClapDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = ClapAudioIntermediate(config, dim) self.output = ClapAudioOutput(config, dim) @@ -712,6 +712,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d dim=dim, input_resolution=input_resolution, num_heads=num_heads, + drop_path_rate=drop_path[i], shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 8d639131b841..2d5272e8642e 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -558,7 +558,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin class DonutSwinLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size @@ -566,7 +566,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size) - self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = DonutSwinIntermediate(config, dim) self.output = DonutSwinOutput(config, dim) @@ -695,6 +695,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d dim=dim, input_resolution=input_resolution, num_heads=num_heads, + drop_path_rate=drop_path[i], shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 9a40e0504598..598e1d8186a2 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -520,16 +520,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MaskFormerSwinLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): super().__init__() self.shift_size = shift_size self.window_size = config.window_size self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size) - self.drop_path = ( - MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() - ) + self.drop_path = MaskFormerSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = MaskFormerSwinIntermediate(config, dim) self.output = MaskFormerSwinOutput(config, dim) @@ -644,6 +642,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d dim=dim, input_resolution=input_resolution, num_heads=num_heads, + drop_path_rate=drop_path[i], shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 45383a36d9be..23f0ba6da620 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -635,7 +635,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class SwinLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size @@ -643,7 +643,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size) - self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = SwinIntermediate(config, dim) self.output = SwinOutput(config, dim) @@ -771,6 +771,7 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d dim=dim, input_resolution=input_resolution, num_heads=num_heads, + drop_path_rate=drop_path[i], shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 035b31e8d43b..f1aa0bfef743 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -742,7 +742,14 @@ def build(self, input_shape=None): class TFSwinLayer(keras.layers.Layer): def __init__( - self, config, dim, input_resolution: Tuple[int, int], num_heads: int, shift_size: int = 0, **kwargs + self, + config, + dim, + input_resolution: Tuple[int, int], + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + **kwargs, ) -> None: super().__init__(**kwargs) self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -754,8 +761,8 @@ def __init__( self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") self.attention = TFSwinAttention(config, dim, num_heads, name="attention") self.drop_path = ( - TFSwinDropPath(config.drop_path_rate, name="drop_path") - if config.drop_path_rate > 0.0 + TFSwinDropPath(drop_path_rate, name="drop_path") + if drop_path_rate > 0.0 else keras.layers.Activation("linear", name="drop_path") ) self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") @@ -913,6 +920,7 @@ def __init__( input_resolution=input_resolution, num_heads=num_heads, shift_size=0 if (i % 2 == 0) else config.window_size // 2, + drop_path_rate=drop_path[i], name=f"blocks.{i}", ) for i in range(depth) diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index b0a773c8af34..d6bd8da9bed6 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -482,7 +482,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Layer with Swinv2->Swin2SR class Swin2SRLayer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0): + def __init__( + self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0 + ): super().__init__() self.input_resolution = input_resolution window_size, shift_size = self._compute_window_shift( @@ -500,7 +502,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretr else (pretrained_window_size, pretrained_window_size), ) self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) - self.drop_path = Swin2SRDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.drop_path = Swin2SRDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.intermediate = Swin2SRIntermediate(config, dim) self.output = Swin2SROutput(config, dim) self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 0c30e739a48f..191923958cfb 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -683,7 +683,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Swinv2Layer(nn.Module): - def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0): + def __init__( + self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0 + ): super().__init__() self.input_resolution = input_resolution window_size, shift_size = self._compute_window_shift( @@ -701,7 +703,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretr else (pretrained_window_size, pretrained_window_size), ) self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) - self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.drop_path = Swinv2DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.intermediate = Swinv2Intermediate(config, dim) self.output = Swinv2Output(config, dim) self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) @@ -819,6 +821,7 @@ def __init__( dim=dim, input_resolution=input_resolution, num_heads=num_heads, + drop_path_rate=drop_path[i], shift_size=0 if (i % 2 == 0) else config.window_size // 2, pretrained_window_size=pretrained_window_size, )