test.py
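"""Evaluate a trained moral classifier on the NELA-COVID-2020 headline test split.

Loads the pickled data split, restores an OneHotMoralClassifier from a saved
state dict, and reports test metrics via PyTorch Lightning's Trainer.test().
"""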
import pickle
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn import metrics

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from models import MoralClassifier
from models.custom_transformer_classifier import OneHotMoralClassifier
from data import NewsDataset

def test(path, gpus):
    """Run the trained classifier on the held-out test split."""
    print("Start")

    # Load the preprocessed headline data (train/val/test splits) from disk.
    # file = open('data/nela-covid-2020/combined/headlines_contentmorals_cnn_bart_split.pkl', 'rb')
    with open('data/nela-covid-2020/combined/headlines_cnn_bart_split.pkl', 'rb') as file:
        data = pickle.load(file)
    print("Data Loaded")

    # Wrap the test split in a Dataset/DataLoader for batched evaluation.
    test_dataset = NewsDataset(data['test'])
    # test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=128, num_workers=4)

    # Rebuild the model and restore the trained weights from the saved state dict.
    # model = OneHotMoralClassifier.load_from_checkpoint(path)
    model = OneHotMoralClassifier({}, use_mask=False)
    model.load_state_dict(torch.load(path))

    # Evaluate with PyTorch Lightning; 'dp' runs data-parallel over the given GPUs.
    trainer = Trainer(gpus=gpus, distributed_backend='dp')
    trainer.test(model, test_dataloaders=test_loader)

if __name__ == '__main__':
    # Use a single GPU when CUDA is available; fall back to CPU otherwise.
    # gpus = torch.cuda.device_count() if torch.cuda.is_available() else None
    gpus = 1 if torch.cuda.is_available() else None

    # Path to the saved state dict of the headline (title) morals discriminator.
    # path = "final_models/dicriminator_contentmorals_state.pkl"
    path = "final_models/discriminator_titlemorals_state.pkl"
    test(path, gpus)
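# Usage (assuming the pickled data and the state-dict file above exist at those paths):
#   python test.py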