
from .utils import sigmoid, softmax, get_list_of_image_predictions, pytorch_where
from . import callbacks as cbk
-from .unet_models import UNetResNet
+from .architectures import UNetResNet, LargeKernelMatters, UNetResNetWithDepth, StackingFCN, StackingFCNWithDepth, \
+    EmptinessClassifier
from .lovasz_losses import lovasz_hinge

-PRETRAINED_NETWORKS = {'ResNet18': {'model': UNetResNet,
-                                    'model_config': {'encoder_depth': 18, 'use_hypercolumn': False,
-                                                     'dropout_2d': 0.0, 'pretrained': True,
-                                                     },
-                                    'init_weights': False},
-                       'ResNet34': {'model': UNetResNet,
-                                    'model_config': {'encoder_depth': 34, 'use_hypercolumn': False,
-                                                     'dropout_2d': 0.0, 'pretrained': True,
-                                                     },
-                                    'init_weights': False},
-                       'ResNet50': {'model': UNetResNet,
-                                    'model_config': {'encoder_depth': 50, 'use_hypercolumn': False,
-                                                     'dropout_2d': 0.0, 'pretrained': True,
-                                                     },
-                                    'init_weights': False},
-                       'ResNet101': {'model': UNetResNet,
-                                     'model_config': {'encoder_depth': 101, 'use_hypercolumn': False,
-                                                      'dropout_2d': 0.0, 'pretrained': True,
-                                                      },
-                                     'init_weights': False},
-                       'ResNet152': {'model': UNetResNet,
-                                     'model_config': {'encoder_depth': 152, 'use_hypercolumn': False,
-                                                      'dropout_2d': 0.0, 'pretrained': True,
-                                                      },
-                                     'init_weights': False},
-                       'ResNetHyper18': {'model': UNetResNet,
-                                         'model_config': {'encoder_depth': 18, 'use_hypercolumn': True,
-                                                          'dropout_2d': 0.0, 'pretrained': True,
-                                                          },
-                                         'init_weights': False},
-                       'ResNetHyper34': {'model': UNetResNet,
+ARCHITECTURES = {'UNetResNet': {'model': UNetResNet,
+                                'model_config': {'encoder_depth': 34, 'use_hypercolumn': True,
+                                                 'dropout_2d': 0.0, 'pretrained': True,
+                                                 },
+                                'init_weights': False},
+
+                 'UNetResNetWithDepth': {'model': UNetResNetWithDepth,
                                         'model_config': {'encoder_depth': 34, 'use_hypercolumn': True,
                                                          'dropout_2d': 0.0, 'pretrained': True,
                                                          },
                                         'init_weights': False},
-                       'ResNetHyper50': {'model': UNetResNet,
-                                         'model_config': {'encoder_depth': 50, 'use_hypercolumn': True,
-                                                          'dropout_2d': 0.0, 'pretrained': True,
+                 'LargeKernelMatters': {'model': LargeKernelMatters,
+                                        'model_config': {'encoder_depth': 34, 'pretrained': True,
+                                                         'kernel_size': 9, 'internal_channels': 21,
+                                                         'dropout_2d': 0.0, 'use_relu': True
+                                                         },
+                                        'init_weights': False},
+                 'StackingFCN': {'model': StackingFCN,
+                                 'model_config': {'input_model_nr': 18, 'filter_nr': 32, 'dropout_2d': 0.0
+                                                  },
+                                 'init_weights': True},
+                 'StackingFCNWithDepth': {'model': StackingFCNWithDepth,
+                                          'model_config': {'input_model_nr': 18, 'filter_nr': 32, 'dropout_2d': 0.0
+                                                           },
+                                          'init_weights': True},
+                 'EmptinessClassifier': {'model': EmptinessClassifier,
+                                         'model_config': {'encoder_depth': 18, 'pretrained': True,
                                                          },
                                         'init_weights': False},
-                       'ResNetHyper101': {'model': UNetResNet,
-                                          'model_config': {'encoder_depth': 101, 'use_hypercolumn': True,
-                                                           'dropout_2d': 0.0, 'pretrained': True,
-                                                           },
-                                          'init_weights': False},
-                       'ResNetHyper152': {'model': UNetResNet,
-                                          'model_config': {'encoder_depth': 152, 'use_hypercolumn': True,
-                                                           'dropout_2d': 0.0, 'pretrained': True,
-                                                           },
-                                          'init_weights': False},
-                       }
+                 }


-class PyTorchUNet(Model):
+class SegmentationModel(Model):
    def __init__(self, architecture_config, training_config, callbacks_config):
        super().__init__(architecture_config, training_config, callbacks_config)
        self.activation_func = self.architecture_config['model_params']['activation']
@@ -75,7 +55,7 @@ def __init__(self, architecture_config, training_config, callbacks_config):
        self.weight_regularization = weight_regularization
        self.optimizer = optim.Adam(self.weight_regularization(self.model, **architecture_config['regularizer_params']),
                                    **architecture_config['optimizer_params'])
-        self.callbacks = callbacks_unet(self.callbacks_config)
+        self.callbacks = callbacks_network(self.callbacks_config)

    def fit(self, datagen, validation_datagen=None, meta_valid=None):
        self._initialize_model_weights()
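The constructor above pulls three nested sections out of architecture_config: model_params (read again in set_model and set_loss), regularizer_params (splatted into weight_regularization) and optimizer_params (splatted into optim.Adam). A minimal sketch of that shape, with placeholder values only, since the real values live in the experiment configuration outside this diff:

architecture_config = {
    'model_params': {'architecture': 'UNetResNet',  # key into ARCHITECTURES, used by set_model
                     'out_channels': 2,             # forwarded to the model as num_classes
                     'activation': 'sigmoid'},      # selects the loss branch in set_loss
    'regularizer_params': {'regularize': True, 'weight_decay_conv2d': 0.0001},  # see weight_regularization below
    'optimizer_params': {'lr': 0.0001},             # any optim.Adam kwargs
}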
@@ -179,8 +159,8 @@ def _transform(self, datagen, validation_datagen=None, **kwargs):
        return outputs

    def set_model(self):
-        encoder = self.architecture_config['model_params']['encoder']
-        config = PRETRAINED_NETWORKS[encoder]
+        architecture = self.architecture_config['model_params']['architecture']
+        config = ARCHITECTURES[architecture]
        self.model = config['model'](num_classes=self.architecture_config['model_params']['out_channels'],
                                     **config['model_config'])
        self._initialize_model_weights = lambda: None
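set_model is now a plain registry lookup: the 'architecture' string picks an ARCHITECTURES entry, which supplies both the model class and its constructor kwargs. For the 'UNetResNet' entry defined above this resolves to roughly the following (the out_channels value is illustrative):

config = ARCHITECTURES['UNetResNet']
model = config['model'](num_classes=2, **config['model_config'])
# i.e. UNetResNet(num_classes=2, encoder_depth=34, use_hypercolumn=True, dropout_2d=0.0, pretrained=True)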
@@ -190,6 +170,7 @@ def set_loss(self):
            raise NotImplementedError('No softmax loss defined')
        elif self.activation_func == 'sigmoid':
            loss_function = lovasz_loss
+            # loss_function = nn.BCEWithLogitsLoss()
        else:
            raise Exception('Only softmax and sigmoid activations are allowed')
        self.loss_function = [('mask', loss_function, 1.0)]
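The commented-out nn.BCEWithLogitsLoss() is an easy alternative to swap in here, since it consumes raw logits and a float mask of the same shape. A standalone illustration with arbitrary shapes (not taken from the project):

import torch
import torch.nn as nn

loss_function = nn.BCEWithLogitsLoss()
logits = torch.randn(4, 1, 101, 101)            # (batch, channels, H, W)
targets = torch.empty_like(logits).random_(2)   # binary mask stored as floats
loss = loss_function(logits, targets)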
@@ -209,6 +190,84 @@ def load(self, filepath):
        return self


+class SegmentationModelWithDepth(SegmentationModel):
+    def __init__(self, architecture_config, training_config, callbacks_config):
+        super().__init__(architecture_config, training_config, callbacks_config)
+        self.activation_func = self.architecture_config['model_params']['activation']
+        self.set_model()
+        self.set_loss()
+        self.weight_regularization = weight_regularization
+        self.optimizer = optim.Adam(self.weight_regularization(self.model, **architecture_config['regularizer_params']),
+                                    **architecture_config['optimizer_params'])
+        self.callbacks = callbacks_network(self.callbacks_config)
+
+    def _fit_loop(self, data):
+        X = data[0]
+        D = data[1]
+        targets_tensors = data[2:]
+
+        if torch.cuda.is_available():
+            X = Variable(X).cuda()
+            D = Variable(D).cuda()
+            targets_var = []
+            for target_tensor in targets_tensors:
+                targets_var.append(Variable(target_tensor).cuda())
+        else:
+            X = Variable(X)
+            D = Variable(D)
+            targets_var = []
+            for target_tensor in targets_tensors:
+                targets_var.append(Variable(target_tensor))
+
+        self.optimizer.zero_grad()
+        outputs_batch = self.model(X, D)
+        partial_batch_losses = {}
+
+        if len(self.output_names) == 1:
+            for (name, loss_function, weight), target in zip(self.loss_function, targets_var):
+                batch_loss = loss_function(outputs_batch, target) * weight
+        else:
+            for (name, loss_function, weight), output, target in zip(self.loss_function, outputs_batch, targets_var):
+                partial_batch_losses[name] = loss_function(output, target) * weight
+            batch_loss = sum(partial_batch_losses.values())
+        partial_batch_losses['sum'] = batch_loss
+
+        batch_loss.backward()
+        self.optimizer.step()
+
+        return partial_batch_losses
+
+    def _transform(self, datagen, validation_datagen=None, **kwargs):
+        self.model.eval()
+
+        batch_gen, steps = datagen
+        outputs = {}
+        for batch_id, data in enumerate(batch_gen):
+            X = data[0]
+            D = data[1]
+
+            if torch.cuda.is_available():
+                X = Variable(X, volatile=True).cuda()
+                D = Variable(D, volatile=True).cuda()
+            else:
+                X = Variable(X, volatile=True)
+                D = Variable(D, volatile=True)
+            outputs_batch = self.model(X, D)
+
+            if len(self.output_names) == 1:
+                outputs.setdefault(self.output_names[0], []).append(outputs_batch.data.cpu().numpy())
+            else:
+                for name, output in zip(self.output_names, outputs_batch):
+                    output_ = output.data.cpu().numpy()
+                    outputs.setdefault(name, []).append(output_)
+            if batch_id == steps:
+                break
+        self.model.train()
+        outputs = {'{}_prediction'.format(name): get_list_of_image_predictions(outputs_) for name, outputs_ in
+                   outputs.items()}
+        return outputs
+
+
class FocalWithLogitsLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=1.0):
        super().__init__()
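_transform in the class added above relies on the pre-0.4 PyTorch idiom Variable(X, volatile=True) to switch off autograd during inference; on PyTorch 0.4+ volatile is ignored and the equivalent is a torch.no_grad() block. A sketch of that modernised inference step, shown here only as a sketch:

import torch

def predict_batch(model, X, D):
    # 0.4+ equivalent of the Variable(..., volatile=True) branch in _transform above
    model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            X, D = X.cuda(), D.cuda()
        outputs_batch = model(X, D)
    return outputs_batch.cpu().numpy()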
@@ -235,7 +294,7 @@ def __init__(self, smooth=0, eps=1e-7):

    def forward(self, output, target):
        return 1 - (2 * torch.sum(output * target) + self.smooth) / (
-            torch.sum(output) + torch.sum(target) + self.smooth + self.eps)
+                torch.sum(output) + torch.sum(target) + self.smooth + self.eps)


def weight_regularization(model, regularize, weight_decay_conv2d):
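The whitespace-only change above sits inside a Dice-style loss whose forward computes 1 - (2*sum(output*target) + smooth) / (sum(output) + sum(target) + smooth + eps), so identical prediction and target give a loss of (almost) zero. A quick numeric check using the defaults from the __init__ shown in the hunk header (smooth=0, eps=1e-7):

import torch

output = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
target = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
smooth, eps = 0, 1e-7
loss = 1 - (2 * torch.sum(output * target) + smooth) / (
        torch.sum(output) + torch.sum(target) + smooth + eps)
print(loss)  # ~0 for a perfect match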
@@ -249,12 +308,13 @@ def weight_regularization(model, regularize, weight_decay_conv2d):
    return parameter_list


-def callbacks_unet(callbacks_config):
+def callbacks_network(callbacks_config):
    experiment_timing = cbk.ExperimentTiming(**callbacks_config['experiment_timing'])
    model_checkpoints = cbk.ModelCheckpoint(**callbacks_config['model_checkpoint'])
    lr_scheduler = cbk.ReduceLROnPlateauScheduler(**callbacks_config['reduce_lr_on_plateau_scheduler'])
    training_monitor = cbk.TrainingMonitor(**callbacks_config['training_monitor'])
    validation_monitor = cbk.ValidationMonitor(**callbacks_config['validation_monitor'])
+    # validation_monitor = cbk.ValidationMonitorEmptiness(**callbacks_config['validation_monitor'])
    neptune_monitor = cbk.NeptuneMonitor(**callbacks_config['neptune_monitor'])
    early_stopping = cbk.EarlyStopping(**callbacks_config['early_stopping'])

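callbacks_network expects callbacks_config to contain one kwargs dict per callback, keyed exactly as in the lookups above; the actual kwargs for each cbk class come from the experiment configuration and are only placeholders here:

callbacks_config = {
    'experiment_timing': {},
    'model_checkpoint': {},
    'reduce_lr_on_plateau_scheduler': {},
    'training_monitor': {},
    'validation_monitor': {},
    'neptune_monitor': {},
    'early_stopping': {},
}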