Commit 636d21d

mannatsingh authored and facebook-github-bot committed
Add Squeeze and Excitation to DenseNets (facebookresearch#427)
Summary:
Pull Request resolved: facebookresearch#427

Plugged the Squeeze and Excitation layer into DenseNets.

Differential Revision: D20358700

fbshipit-source-id: 2ef6df1b7257c85d97ec78a7c842cd9824ab253d
1 parent 2732a5e commit 636d21d
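
The SqueezeAndExcitationLayer wired in below is imported from classy_vision/models/common.py; its implementation is not part of this diff. As a rough, illustrative sketch of what such a layer typically does (the class and parameter names here are assumptions, not the actual Classy Vision code): each channel is globally average-pooled, passed through a small bottleneck, and the resulting sigmoid gates rescale the channels of the input feature map.

import torch
import torch.nn as nn


class SqueezeExcitationSketch(nn.Module):
    """Illustrative channel-wise squeeze-and-excitation gate (Hu et al., 2018)."""

    def __init__(self, in_planes: int, reduction_ratio: int = 16):
        super().__init__()
        reduced_planes = max(1, in_planes // reduction_ratio)
        # "Squeeze": global average pool collapses each channel to a single value.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # "Excitation": bottleneck of 1x1 convolutions produces per-channel gates.
        self.excitation = nn.Sequential(
            nn.Conv2d(in_planes, reduced_planes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_planes, in_planes, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.excitation(self.squeeze(x))  # shape (N, C, 1, 1)
        return x * scale  # rescale each channel of the input feature map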

File tree: 2 files changed (+63 -13 lines)


classy_vision/models/densenet.py (+42 -12)
@@ -16,19 +16,24 @@
 
 from . import register_model
 from .classy_model import ClassyModel
+from .common import SqueezeAndExcitationLayer
 
 
 # global setting for in-place ReLU:
 INPLACE = True
 
 
 class _DenseLayer(nn.Sequential):
-    """
-    Single layer of a DenseNet.
-    """
-
-    def __init__(self, in_planes, growth_rate=32, expansion=4):
+    """Single layer of a DenseNet."""
 
+    def __init__(
+        self,
+        in_planes,
+        growth_rate=32,
+        expansion=4,
+        use_se=False,
+        se_reduction_ratio=16,
+    ):
         # assertions:
         assert is_pos_int(in_planes)
         assert is_pos_int(growth_rate)
@@ -56,6 +61,13 @@ def __init__(self, in_planes, growth_rate=32, expansion=4):
                 bias=False,
             ),
         )
+        if use_se:
+            self.add_module(
+                "se",
+                SqueezeAndExcitationLayer(
+                    growth_rate, reduction_ratio=se_reduction_ratio
+                ),
+            )
 
     def forward(self, x):
         new_features = super(_DenseLayer, self).forward(x)
@@ -98,22 +110,27 @@ def __init__(
         expansion,
         small_input,
         final_bn_relu,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
         """
         Implementation of a standard densely connected network (DenseNet).
 
-        Set `small_input` to `True` for 32x32 sized image inputs.
-
-        Set `final_bn_relu` to `False` to exclude the final batchnorm and ReLU
-        layers. These settings are useful when
-        training Siamese networks.
-
         Contains the following attachable blocks:
             block{block_idx}-{idx}: This is the output of each dense block,
                 indexed by the block index and the index of the dense layer
            transition-{idx}: This is the output of the transition layers
            trunk_output: The final output of the `DenseNet`. This is
                where a `fully_connected` head is normally attached.
+
+        Args:
+            small_input: set to `True` for 32x32 sized image inputs.
+            final_bn_relu: set to `False` to exclude the final batchnorm and
+                ReLU layers. These settings are useful when training Siamese
+                networks.
+            use_se: Enable squeeze and excitation
+            se_reduction_ratio: The reduction ratio to apply in the excitation
+                stage. Only used if `use_se` is `True`.
         """
         super().__init__()
 
@@ -158,6 +175,8 @@ def __init__(
                 idx,
                 growth_rate=growth_rate,
                 expansion=expansion,
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
             blocks.append(block)
             num_planes = num_planes + num_layers * growth_rate
@@ -192,7 +211,14 @@ def _make_trunk_output_block(self, num_planes, final_bn_relu):
         return self.build_attachable_block("trunk_output", layers)
 
     def _make_dense_block(
-        self, num_layers, in_planes, block_idx, growth_rate=32, expansion=4
+        self,
+        num_layers,
+        in_planes,
+        block_idx,
+        growth_rate=32,
+        expansion=4,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
         assert is_pos_int(in_planes)
         assert is_pos_int(growth_rate)
@@ -208,6 +234,8 @@ def _make_dense_block(
                         in_planes + idx * growth_rate,
                         growth_rate=growth_rate,
                         expansion=expansion,
+                        use_se=use_se,
+                        se_reduction_ratio=se_reduction_ratio,
                     ),
                 )
             )
@@ -233,6 +261,8 @@ def from_config(cls, config: Dict[str, Any]) -> "DenseNet":
             "expansion": config.get("expansion", 4),
             "small_input": config.get("small_input", False),
             "final_bn_relu": config.get("final_bn_relu", True),
+            "use_se": config.get("use_se", False),
+            "se_reduction_ratio": config.get("se_reduction_ratio", 16),
         }
         return cls(**config)
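With the keys read in from_config above, enabling squeeze and excitation becomes a config-only change. A minimal usage sketch, assuming the set of required keys mirrors the test config below (illustrative, not taken from the repository documentation):

from classy_vision.models.densenet import DenseNet

config = {
    "num_blocks": [1, 1, 1, 1],
    "init_planes": 4,
    "growth_rate": 32,
    "expansion": 4,
    "small_input": True,
    "final_bn_relu": True,
    "use_se": True,              # new flag added in this commit
    "se_reduction_ratio": 16,    # only consulted when use_se is True
}
model = DenseNet.from_config(config)
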
test/models_densenet_test.py (+21 -1)
@@ -30,7 +30,27 @@
                 "zero_init_bias": True,
             }
         ],
-    }
+    },
+    "small_densenet_se": {
+        "name": "densenet",
+        "num_blocks": [1, 1, 1, 1],
+        "init_planes": 4,
+        "growth_rate": 32,
+        "expansion": 4,
+        "final_bn_relu": True,
+        "small_input": True,
+        "use_se": True,
+        "heads": [
+            {
+                "name": "fully_connected",
+                "unique_id": "default_head",
+                "num_classes": 1000,
+                "fork_block": "trunk_output",
+                "in_plane": 60,
+                "zero_init_bias": True,
+            }
+        ],
+    },
 }
 
 
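A quick, hypothetical sanity check on the new config (not part of the test file; the name of the config dict, MODELS, and the baseline key "small_densenet" are assumed from context): the SE variant should build successfully and carry more parameters than the baseline, the difference being the squeeze-and-excitation weights.

from classy_vision.models.densenet import DenseNet

baseline = DenseNet.from_config(MODELS["small_densenet"])     # assumed baseline entry
se_model = DenseNet.from_config(MODELS["small_densenet_se"])  # entry added in this commit

n_baseline = sum(p.numel() for p in baseline.parameters())
n_se = sum(p.numel() for p in se_model.parameters())
assert n_se > n_baseline  # extra parameters come from the SE bottleneck layers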