@@ -139,30 +139,38 @@ def fit(train_gen, valid_gen, epochs):
139
139
Nt = len (train_gen )
140
140
141
141
prev_loss = np .inf
142
+ epoch_dice_loss = tf .keras .metrics .Mean ()
143
+ epoch_dice_loss_val = tf .keras .metrics .Mean ()
142
144
143
145
for e in range (epochs ):
144
146
print ('Epoch {}/{}' .format (e + 1 ,epochs ))
145
147
b = 0
146
148
for Xb , yb in train_gen :
147
149
b += 1
148
150
loss = train_step (Xb , yb )
149
- stdout .write ('\r Batch: {}/{} - dice_loss: {:.4f}' .format (b , Nt , loss ))
151
+ epoch_dice_loss .update_state (loss )
152
+ stdout .write ('\r Batch: {}/{} - dice_loss: {:.4f}' .format (b , Nt , epoch_dice_loss .result ()))
150
153
stdout .flush ()
151
154
152
155
for Xb , yb in valid_gen :
153
156
loss_val = test_step (Xb , yb )
154
- stdout .write ('\n dice_loss_val: {:.4f}' .format (loss_val ))
157
+ epoch_dice_loss_val .update_state (loss_val )
158
+ stdout .write ('\n dice_loss_val: {:.4f}' .format (epoch_dice_loss_val .result ()))
155
159
stdout .flush ()
156
160
157
161
# save models
158
162
print (' ' )
159
- if loss_val [ 0 ] < prev_loss :
163
+ if epoch_dice_loss_val . result () < prev_loss :
160
164
E .save_weights (path + '/Ensembler.h5' )
161
165
print ("Validation loss decresaed from {:.4f} to {:.4f}. Models' weights are now saved."
162
- .format (prev_loss , loss_val ))
163
- prev_loss = losses_val [ 0 ]
166
+ .format (prev_loss , epoch_dice_loss_val . result () ))
167
+ prev_loss = epoch_dice_loss_val . result ()
164
168
else :
165
169
print ("Validation loss did not decrese from {:.4f}." .format (prev_loss ))
166
170
print (' ' )
167
171
172
+ # reset losses state
173
+ epoch_dice_loss .reset_states ()
174
+ epoch_dice_loss_val .reset_states ()
175
+
168
176
del Xb , yb
0 commit comments