Skip to content

Commit bb4cf65

Browse files
authored
Update train_ensembler.py
1 parent e714765 commit bb4cf65

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

train_ensembler.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -139,30 +139,38 @@ def fit(train_gen, valid_gen, epochs):
139139
Nt = len(train_gen)
140140

141141
prev_loss = np.inf
142+
epoch_dice_loss = tf.keras.metrics.Mean()
143+
epoch_dice_loss_val = tf.keras.metrics.Mean()
142144

143145
for e in range(epochs):
144146
print('Epoch {}/{}'.format(e+1,epochs))
145147
b = 0
146148
for Xb, yb in train_gen:
147149
b += 1
148150
loss = train_step(Xb, yb)
149-
stdout.write('\rBatch: {}/{} - dice_loss: {:.4f}'.format(b, Nt, loss))
151+
epoch_dice_loss.update_state(loss)
152+
stdout.write('\rBatch: {}/{} - dice_loss: {:.4f}'.format(b, Nt, epoch_dice_loss.result()))
150153
stdout.flush()
151154

152155
for Xb, yb in valid_gen:
153156
loss_val = test_step(Xb, yb)
154-
stdout.write('\n dice_loss_val: {:.4f}'.format(loss_val))
157+
epoch_dice_loss_val.update_state(loss_val)
158+
stdout.write('\n dice_loss_val: {:.4f}'.format(epoch_dice_loss_val.result()))
155159
stdout.flush()
156160

157161
# save models
158162
print(' ')
159-
if loss_val[0] < prev_loss:
163+
if epoch_dice_loss_val.result() < prev_loss:
160164
E.save_weights(path + '/Ensembler.h5')
161165
print("Validation loss decresaed from {:.4f} to {:.4f}. Models' weights are now saved."
162-
.format(prev_loss, loss_val))
163-
prev_loss = losses_val[0]
166+
.format(prev_loss, epoch_dice_loss_val.result()))
167+
prev_loss = epoch_dice_loss_val.result()
164168
else:
165169
print("Validation loss did not decrese from {:.4f}.".format(prev_loss))
166170
print(' ')
167171

172+
# reset losses state
173+
epoch_dice_loss.reset_states()
174+
epoch_dice_loss_val.reset_states()
175+
168176
del Xb, yb

0 commit comments

Comments
 (0)