-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
34 lines (29 loc) · 1.24 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import pickle
import logging
from utils import RandomSplitter
def _top1_accuracy(split):
    """Return the fraction of rows whose argmax over dim 1 equals the label.

    Args:
        split: Mapping with a ``"logits"`` tensor (at least 2-D, since the
            argmax is taken over ``dim=1``) and a ``"labels"`` tensor that is
            compared element-wise against the predictions.

    Returns:
        The mean of the element-wise match mask as a Python float.
    """
    predictions = torch.argmax(split["logits"], dim=1)
    hits = predictions == split["labels"]
    return hits.float().mean().item()


def load_data(data_config,
              test_splits=(0.1, 0.9),
              seed=None):
    """Load the pickled validation/test sets and split the test set in two.

    Both files are unpickled, their argmax accuracy is logged at INFO level,
    and the test set's ``"logits"``, ``"labels"`` and ``"features"`` entries
    are each passed through the same ``RandomSplitter`` instance.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file, so
    the configured paths must point at trusted data.

    Args:
        data_config: Object exposing ``val_path`` and ``test_path``, the
            locations of the two pickle files.
        test_splits: Forwarded unchanged to ``RandomSplitter`` as ``splits``
            (presumably the relative sizes of the two parts -- confirm against
            ``utils.RandomSplitter``).
        seed: Forwarded unchanged to ``RandomSplitter`` as ``seed``.

    Returns:
        A 3-tuple ``(val_data, test_train_data, test_test_data)``. ``val_data``
        is the unpickled validation object as loaded; the other two are dicts
        with keys ``"logits"``, ``"labels"`` and ``"features"`` holding the
        first and second piece returned by ``RandomSplitter.split``.

    Raises:
        KeyError: If the loaded test data lacks one of the three keys above
            (or either set lacks ``"logits"``/``"labels"``).
    """
    with open(data_config.val_path, "rb") as val_file:
        val_data = pickle.load(val_file)
    with open(data_config.test_path, "rb") as test_file:
        test_data = pickle.load(test_file)

    val_acc = _top1_accuracy(val_data)
    test_acc = _top1_accuracy(test_data)
    logging.info("Dataset: val_acc: {:.4f}, test_acc: {:.4f}".format(val_acc, test_acc))

    # One splitter is shared by every field; it is sized by the number of
    # rows in the test logits.
    num_test_rows = test_data["logits"].shape[0]
    test_splitter = RandomSplitter(splits=test_splits,
                                   num=num_test_rows,
                                   seed=seed)

    test_train_data, test_test_data = {}, {}
    # Keep this key order: the splitter is called once per field, in sequence.
    for key in ("logits", "labels", "features"):
        test_train_data[key], test_test_data[key] = test_splitter.split(
            test_data[key])

    return val_data, test_train_data, test_test_data