Skip to content

Commit 6203db8

Browse files
authored
Update eval_model.py
1 parent 6d78afc commit 6203db8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

utils/eval_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def dt():
1818
return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
1919

20-
def eval_turn(model, data_loader, val_version, epoch_num, log_file):
20+
def eval_turn(Config, model, data_loader, val_version, epoch_num, log_file):
2121

2222
model.train(False)
2323

@@ -50,7 +50,7 @@ def eval_turn(model, data_loader, val_version, epoch_num, log_file):
5050
val_loss_recorder.update(loss)
5151
val_celoss_recorder.update(ce_loss)
5252

53-
if outputs[1].size(1) != 2:
53+
if Config.use_dcl and Config.cls_2xmul:
5454
outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls]
5555
else:
5656
outputs_pred = outputs[0]

0 commit comments

Comments
 (0)