Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xwhan authored Jul 28, 2019
1 parent 8dd499f commit 257b1ae
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ def train(cfg):
# model.load_state_dict(torch.load(model_save_path))


print('..........Finished training, start testing.......')
print('\n..........Finished training, start testing.......')

test_data = DataLoader(cfg, documents, mode='test')
model.eval()
print('finished training, testing final model...')
test(model, test_data, cfg['eps'])

print('testing best model...')
model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
model.load_state_dict(torch.load(model_save_path))
model.eval()
test(model, test_data, cfg['eps'])
# print('testing best model...')
# model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
# model.load_state_dict(torch.load(model_save_path))
# model.eval()
# test(model, test_data, cfg['eps'])


def test(model, test_data, eps):
Expand Down Expand Up @@ -153,8 +153,6 @@ def test(model, test_data, eps):
print('avg_f1', np.mean(f1s))
print('avg_hits', np.mean(hits))



model.train()
return np.mean(f1s), np.mean(hits)

Expand All @@ -177,4 +175,4 @@ def test(model, test_data, eps):
model.eval()
test(model, test_data, cfg['eps'])
else:
assert False, "--train or --test?"
assert False, "--train or --test?"

0 comments on commit 257b1ae

Please sign in to comment.