-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_ewc_pen_with_expan.py
115 lines (72 loc) · 2.79 KB
/
plot_ewc_pen_with_expan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
from copy import deepcopy
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import stats
import matplotlib.pylab as pylab
# Use the same enlarged font size for axes titles and axis labels in every
# figure this script produces.
params = dict.fromkeys(('axes.titlesize', 'axes.labelsize'), 'x-large')
pylab.rcParams.update(params)

# Prefix for every saved plot. It is concatenated directly with the file name,
# so it has to end with a path separator.
DIRECTORY = 'final/plots/'  # TODO change this as needed
def parse_h5(filename):
    """Read the EWC-penalty and average-accuracy series from an HDF5 file.

    Args:
        filename: Path of the .h5 file to open read-only. It must contain
            the datasets ``ewc_pen`` and ``avg_acc``.

    Returns:
        A tuple ``(ewc_pens, avg_accs)`` of lists holding the entries of the
        two datasets in iteration order.

    Errors raised by h5py (unreadable file, missing dataset) are not caught
    here and propagate to the caller.
    """
    # The context manager closes the file even when a dataset lookup or the
    # iteration fails, which the previous explicit close() did not guarantee.
    with h5py.File(filename, 'r') as f:
        ewc_pens = list(f['ewc_pen'])
        avg_accs = list(f['avg_acc'])
    return ewc_pens, avg_accs
def plot_line_avg_acc(avg_accuracies, labels):
    """Draw one average-accuracy curve per run and save the figure as a PDF.

    Args:
        avg_accuracies: Sequence of accuracy series, one per run. The length
            of the first series is used as the upper x-axis limit.
        labels: Legend labels, looked up by the same index as the series.

    The figure is written to ``avg_acc.pdf`` under ``DIRECTORY``.
    """
    plt.figure()

    # Labels are looked up by index so that a missing label raises instead of
    # silently dropping a curve.
    for idx, series in enumerate(avg_accuracies):
        plt.plot(series, label=labels[idx])

    plt.xlabel('Total Task Count')
    plt.ylabel('Average Accuracy on All Tasks')
    plt.xlim(1, len(avg_accuracies[0]))
    plt.legend(ncol=1, fancybox=True, shadow=True)

    output_path = f'{DIRECTORY}avg_acc.pdf'
    plt.savefig(output_path, dpi=300, format='pdf')
def plot_line_ewc_pen(ewc_pens, labels, markers=(2, 4, 6, 10, 16, 40),
                      ylim=(0.0, 0.8)):
    """Draw one EWC-penalty curve per run and save the figure as a PDF.

    Args:
        ewc_pens: Sequence of penalty series, one per run. The length of the
            first series is used as the upper x-axis limit.
        labels: Legend labels, looked up by the same index as the series.
        markers: X positions at which a red vertical line is drawn. The
            default reproduces the previously hard-coded positions
            (presumably the task counts at which the variable-capacity
            network was expanded -- TODO confirm). Pass an empty sequence or
            ``None`` to draw no lines.
        ylim: ``(bottom, top)`` limits for the y-axis, or ``None`` to leave
            the limits to matplotlib. The default reproduces the previously
            hard-coded range.

    The figure is written to ``ewc_pen_w_exp.pdf`` under ``DIRECTORY``.
    """
    plt.figure()

    # Labels are looked up by index so that a missing label raises instead of
    # silently dropping a curve.
    for idx, series in enumerate(ewc_pens):
        plt.plot(series, label=labels[idx])

    plt.ylabel('EWC Loss Penalty')
    plt.xlabel('Total Task Count')
    plt.xlim(1, len(ewc_pens[0]))
    if ylim is not None:
        plt.ylim(*ylim)

    for mark in markers or ():
        plt.axvline(x=mark, color='r')

    # Place the legend just above the axes so it does not cover the curves.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
               fancybox=True, shadow=True, ncol=5)
    plt.savefig('{}ewc_pen_w_exp.pdf'.format(DIRECTORY), dpi=300, format='pdf')
def main():
    """Load the given result files and plot their EWC loss penalties.

    Command line:
        --filenames F [F ...]  one or more .h5 result files. When the option
            is omitted, the default ``['NONE']`` is passed to ``parse_h5``
            unchanged.

    The first two files are labelled "Variable Capacity" and "Fixed Size",
    in that order; any further file is labelled with its own filename.
    """
    parser = argparse.ArgumentParser(description='Plotting Tool')
    parser.add_argument('--filenames',
                        nargs='+', type=str, default=['NONE'],
                        metavar='FILENAMES',
                        help='names of .h5 files containing experimental result data')
    args = parser.parse_args()
    print(args.filenames)

    ewc_pens_list = []
    avg_accs_list = []
    for filename in args.filenames:
        ewc_pens, avg_accs = parse_h5(filename)
        ewc_pens_list.append(ewc_pens)
        avg_accs_list.append(avg_accs)

    # Report the number of files, then the series length of each file. The
    # loop avoids the IndexError the fixed [0]/[1] lookups raised when only
    # one file was given; output for exactly two files is unchanged.
    print(len(ewc_pens_list))
    for ewc_pens in ewc_pens_list:
        print(len(ewc_pens))

    labels = ["Variable Capacity", "Fixed Size"]
    # Files beyond the two named runs fall back to their filename so that the
    # plot function always finds a label for every series.
    labels += args.filenames[len(labels):]

    # plot_line_avg_acc(avg_accs_list, labels)  # enable to also plot accuracy
    plot_line_ewc_pen(ewc_pens_list, labels)


if __name__ == '__main__':
    main()