-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalidation.py
68 lines (53 loc) · 1.81 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import time
import statistics
from src.l3svms import *
from src.utils import *
# Read the command-line options for this run and expose them as module-level
# constants; the "default" notes are carried over from the original comments.
args = get_args(__file__)

# Input files.
TRAIN = args.train_file
TEST = args.test_file

# Model options.
LAND = args.nb_landmarks      # default 10
CLUS = args.nb_clusters       # default 1
LIN = args.linear             # default True
PCA_BOOL = args.pca           # default False

# Data handling and run options.
NORM = args.norm              # default False
YPOS = args.y_pos             # default 0
ITER = args.nb_iterations     # default 1
VERB = args.verbose           # default False

# verboseprint is the builtin print in verbose mode, and a no-op that accepts
# (and discards) any arguments otherwise.
if VERB:
    verboseprint = print
else:
    verboseprint = lambda *a, **k: None

# Summarise the configuration (only visible in verbose mode).
verboseprint(
    "training on {}, testing on {}: {} clusters, {} landmarks".format(
        TRAIN, TEST, CLUS, LAND
    )
)
verboseprint("linear kernel" if LIN else "rbf kernel")
verboseprint("normalized dataset" if NORM else "scaled data")
# Load the train and test sets and report how long the whole step took.
t1 = time.time()

# The sparse loader is tried first. If it fails on either file, BOTH files are
# (re)loaded with the dense loader.
try:
    train_y, train_x = load_sparse_dataset(TRAIN, norm=NORM, y_pos=YPOS)
    test_y, test_x = load_sparse_dataset(TEST, norm=NORM, y_pos=YPOS)
except Exception:
    # `except Exception` instead of a bare `except`: a bare clause also
    # catches KeyboardInterrupt/SystemExit, so interrupting the load would
    # silently start the dense fallback instead of stopping the script.
    train_y, train_x = load_dense_dataset(TRAIN, norm=NORM, y_pos=YPOS)
    test_y, test_x = load_dense_dataset(TEST, norm=NORM, y_pos=YPOS)

t2 = time.time()
verboseprint("dataset loading time:", t2 - t1, "s")
# Report how landmarks are chosen. In PCA mode the requested number of
# landmarks is checked against the number of columns of train_x first.
if not PCA_BOOL:
    verboseprint("random landmarks")
elif LAND > train_x.shape[1]:
    raise Exception("When using PCA, the nb landmarks must be at most the nb of features")
else:
    verboseprint("landmarks = principal components")

# Separator between the configuration summary and the per-run output.
verboseprint("--------------------\n")
# Run the learning procedure ITER times, collecting the two values returned by
# each run (reported below as accuracy and time).
acc_list = []
time_list = []
for _ in range(ITER):
    # The second value is bound to `elapsed`, not `time`, so the imported
    # `time` module is not shadowed for any code that runs afterwards.
    acc, elapsed = learning(
        train_x, train_y, test_x, test_y,
        verboseprint, CLUS, PCA_BOOL, LIN, LAND,
    )
    acc_list.append(acc)
    time_list.append(elapsed)

print("Mean accuracy (%), mean stdev (%), mean time (s) over {} iterations:".format(ITER))
try:
    print(
        statistics.mean(acc_list),
        statistics.stdev(acc_list),
        statistics.mean(time_list),
    )
except statistics.StatisticsError:
    # statistics.stdev needs at least two samples; with a single run, report
    # that run's values and a standard deviation of 0. Only StatisticsError is
    # caught so that unrelated failures (and Ctrl-C) are no longer hidden.
    # NOTE(review): with ITER == 0 the lists are empty and the line below
    # raises IndexError, as the original did.
    print(acc_list[0], 0., time_list[0])