-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnn_plot.py
executable file
·45 lines (38 loc) · 1.13 KB
/
nn_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/python3.6
import numpy as np
import matplotlib.pyplot as plt
import sys
import csv
def read_accuracy_log(filename):
    """Read one accuracy log and return its columns as three lists.

    Every non-blank row of the comma-separated file must hold at least
    three fields: an integer (plotted on the "Iteration" axis), then the
    train value and the test value as floats (plotted on the "Accuracy"
    axis). Spaces following a comma are ignored.

    Args:
        filename: Path of the log file to read.

    Returns:
        A tuple ``(iterations, train_scores, test_scores)`` of equally
        long lists, in file order.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-blank row has fewer than three fields.
        ValueError: If a field cannot be parsed as a number.
    """
    iterations = []
    train_scores = []
    test_scores = []
    # newline="" is what the csv module asks for on files handed to a
    # reader, so that it can deal with line endings itself.
    with open(filename, "r", newline="") as f:
        reader = csv.reader(f, delimiter=",", skipinitialspace=True)
        for row in reader:
            # A blank (or whitespace-only) line, e.g. a trailing newline at
            # the end of the log, carries no data point; skip it instead of
            # crashing on row[0].
            if not any(field.strip() for field in row):
                continue
            iterations.append(int(row[0]))
            train_scores.append(float(row[1]))
            test_scores.append(float(row[2]))
    return iterations, train_scores, test_scores


if __name__ == '__main__':
    # Plot train and test accuracy per iteration for every algorithm into a
    # single image.
    algs = ["random_hill_climbing", "simulated_annealing", "genetic"]
    output = "nn_plot.jpg"
    # One colour per algorithm (same order as `algs`): the train curve is
    # drawn dotted, the test curve solid.
    train_line = ["b:", "r:", "g:"]
    test_line = ["b-", "r-", "g-"]

    # Load every log before touching the figure, so a missing or malformed
    # file aborts the script before any plotting happens.
    curves = [read_accuracy_log("./out/" + algo + ".out") for algo in algs]

    plt.figure(figsize=(16, 16))
    for algo, curve, train_style, test_style in zip(
            algs, curves, train_line, test_line):
        iterations, train_scores, test_scores = curve
        plt.plot(iterations, train_scores, train_style,
                 linewidth=1, label=algo + "_train")
        plt.plot(iterations, test_scores, test_style,
                 linewidth=1, label=algo + "_test")
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.savefig(output)