-
Notifications
You must be signed in to change notification settings - Fork 18
/
count_samples.py
94 lines (84 loc) · 2.71 KB
/
count_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import ast
import colorsys
import sys
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
def report_samples(infile):
    """Plot, for each agent, how often every sample type occurs per epoch.

    The samples are read from *infile* via ``get_sample_counts``.  Agents are
    handled in sorted order; for each one a line per sample type is drawn
    (x = epoch index, y = occurrences of that type in the epoch), a legend is
    added and the figure is shown.  Progress labels are printed to stdout.
    """
    counts = get_sample_counts(infile)
    for agent in sorted(counts):
        epochs = counts[agent]
        print('Agent: %s' % agent)

        print('Types')
        seen = {sample for epoch in epochs for sample in epoch}
        # Skip types mentioning '<unk>' (presumably an unknown-token marker)
        # or '?' (e.g. normalize_color's fallback label for unnamed colors).
        types = sorted(s for s in seen if '<unk>' not in s and '?' not in s)

        print('Counters')
        per_epoch = [Counter(epoch) for epoch in epochs]

        print('Counts')
        # Row = sample type, column = epoch.
        mat = np.array([[counter[kind] for counter in per_epoch] for kind in types])

        print('Plots')
        xs = np.arange(len(per_epoch))
        for row_index, kind in enumerate(types):
            print('Plot: %s' % kind)
            plt.plot(xs, mat[row_index, :])
        plt.legend(['%s: %s' % (agent, kind) for kind in types])
        plt.show()
def get_sample_counts(infile):
    """Group normalized sample lines by agent and by epoch.

    A line ending in ``' samples:'`` starts a block for the agent named by the
    text before that suffix.  Every following line containing ``' -> '`` is a
    sample and is normalized with ``parse_sample``.  The first line without
    ``' -> '`` closes the block (one "epoch"), and may itself start a new
    agent's block.  A block still open at end of input is also kept.

    Prints a progress line every 100000 input lines.

    Args:
        infile: any iterable of text lines (e.g. an open file).

    Returns:
        ``defaultdict(list)`` mapping agent name -> list of epochs, where each
        epoch is the list of normalized sample strings in that block.
    """
    suffix = ' samples:'
    agent = None
    counts = defaultdict(list)
    current_samples = []
    for i, line in enumerate(infile):
        if i % 100000 == 0:
            print('Line %d' % i)
        line = line.strip()
        if agent is not None and ' -> ' in line:
            current_samples.append(parse_sample(line))
            continue
        if agent is not None:
            # A non-sample line ends the current agent's epoch.
            counts[agent].append(current_samples)
            current_samples = []
        # Either this line opens a new agent block, or we are outside any block.
        agent = line[:-len(suffix)] if line.endswith(suffix) else None
    if agent is not None:
        # Input ended inside a block: keep the final epoch instead of dropping it.
        counts[agent].append(current_samples)
    return counts
def parse_sample(line):
    """Normalize both sides of a ``source -> target`` sample line.

    Each side is passed through ``normalize_color`` and the two results are
    re-joined with ``' -> '``.  Raises ValueError unless the line contains
    exactly one ``' -> '`` separator.

    >>> parse_sample("'teal' -> (180, 100, 100)")
    "'teal' -> C"
    >>> parse_sample("(240, 100, 100) -> 'blue'")
    "B -> 'blue'"
    """
    source, target = line.split(' -> ')
    return ' -> '.join([normalize_color(source), normalize_color(target)])
# Maps the set of RGB channels above 0.5 to a one-letter color label; every
# channel combination without a dedicated label collapses to '?'.
_CHANNELS_TO_LABEL = {
    '': '?',
    'R': '?',
    'G': '?',
    'B': 'B',
    'RG': 'Y',
    'RB': '?',
    'GB': 'C',
    'RGB': '?',
}


def normalize_color(fragment):
    """Reduce one side of a sample to a coarse token.

    ``fragment`` is the source text of a Python literal.  A tuple is read as
    HSV ``(hue in degrees, saturation in %, value in %)``, converted to RGB,
    and mapped to 'B', 'Y', 'C' or '?' depending on which channels exceed
    0.5.  Any other literal (e.g. a quoted word) is returned unchanged, as its
    original source text.

    Raises ValueError or SyntaxError if ``fragment`` is not a plain literal.

    >>> normalize_color("'blue'")
    "'blue'"
    >>> normalize_color("(60, 100, 100)")
    'Y'
    >>> normalize_color("(120, 100, 100)")
    '?'
    >>> normalize_color("(180, 100, 100)")
    'C'
    """
    # literal_eval rather than eval: fragments come from an input file and
    # must never be executed as code.
    value = ast.literal_eval(fragment)
    if not isinstance(value, tuple):
        return fragment
    hsv_0_1 = (value[0] / 360.0, value[1] / 100.0, value[2] / 100.0)
    r, g, b = colorsys.hsv_to_rgb(*hsv_0_1)
    components = 'R' * (r > 0.5) + 'G' * (g > 0.5) + 'B' * (b > 0.5)
    return _CHANNELS_TO_LABEL[components]
if __name__ == '__main__':
    # Expects the path of the samples log as the first argument; extra
    # arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit('Usage: %s SAMPLES_FILE' % sys.argv[0])
    with open(sys.argv[1], 'r') as infile:
        report_samples(infile)