-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtraining.py
146 lines (123 loc) · 5.67 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import os
import random
import signal
import leadsheet
import constants
import param_cvt
import pickle as pickle
import traceback
from pprint import pformat
# Number of leadsheet segments drawn per training/validation batch.
BATCH_SIZE = 10
# Timesteps per bar: one whole note at the configured resolution
# (used as the stride when choosing segment start offsets, and as the
# bar divisor in get_batch's sample-source strings).
SEGMENT_STEP = constants.WHOLE//constants.RESOLUTION_SCALAR
# Each training segment spans four bars' worth of timesteps.
SEGMENT_LEN = 4*SEGMENT_STEP
def set_params(batch_size, segment_step, segment_len):
    """Override the module-level batching constants.

    Replaces BATCH_SIZE, SEGMENT_STEP, and SEGMENT_LEN for all
    subsequent calls to get_batch / filter_leadsheets.
    """
    global BATCH_SIZE, SEGMENT_STEP, SEGMENT_LEN
    BATCH_SIZE, SEGMENT_STEP, SEGMENT_LEN = batch_size, segment_step, segment_len
VALIDATION_CT = 5
def find_leadsheets(dirpath):
    """Return the paths of all '.ls' files directly inside dirpath."""
    found = []
    for entry in os.listdir(dirpath):
        if entry.endswith('.ls'):
            found.append(os.path.join(dirpath, entry))
    return found
def filter_leadsheets(leadsheets):
    """Parse each leadsheet file and keep only those long enough for a segment.

    Files shorter than SEGMENT_LEN timesteps are skipped (with a message).
    Returns the list of filenames that passed the length check.
    """
    kept = []
    for fname in leadsheets:
        print("---- {} ----".format(fname))
        chords, melody = leadsheet.parse_leadsheet(fname, verbose=True)
        # Too-short sheets cannot yield even a single training segment.
        if leadsheet.get_leadsheet_length(chords, melody) < SEGMENT_LEN:
            print("Leadsheet {} is too short! Skipping...".format(fname))
        else:
            kept.append(fname)
    # NOTE: reports the count of files examined, not the count kept.
    print("Found {} leadsheets.".format(len(leadsheets)))
    return kept
def get_batch(leadsheets, with_sample=False):
    """
    Get a batch
    leadsheets should be a list of dataset lists of (chord, melody) tuples, or just a dataset list of tuples
    returns: chords, melodies
    """
    # Promote a single dataset to a one-element list of datasets.
    if not isinstance(leadsheets[0], list):
        leadsheets = [leadsheets]
    # Two separate comprehensions (not one interleaved loop) so the
    # sequence of calls into the random module is unchanged.
    dataset_picks = [random.randrange(len(leadsheets)) for _ in range(BATCH_SIZE)]
    fns = [random.choice(leadsheets[idx]) for idx in dataset_picks]
    parsed = [leadsheet.parse_leadsheet(fn) for fn in fns]
    lengths = [leadsheet.get_leadsheet_length(c, m) for c, m in parsed]
    # Pick a bar-aligned start offset; an exact-length sheet must start at 0.
    offsets = [(0 if length == SEGMENT_LEN
                else random.randrange(0, length - SEGMENT_LEN, SEGMENT_STEP))
               for length in lengths]
    segments = [leadsheet.slice_leadsheet(c, m, off, off + SEGMENT_LEN)
                for (c, m), off in zip(parsed, offsets)]
    # Transpose [(chords, melody), ...] into (all_chords, all_melodies).
    res = list(zip(*segments))
    sources = ["{}: starting at {} = bar {}".format(fn, start, start/(constants.WHOLE//constants.RESOLUTION_SCALAR)) for fn, start in zip(fns, offsets)]
    if with_sample:
        return res, sources
    else:
        return res
def generate(model, leadsheets, filename, with_vis=False, batch=None):
    """Sample the model on a batch and write each result as a leadsheet file.

    Args:
        model: object exposing produce(chords, melody) ->
            (generated_out, chosen, vis_probs, vis_info).
        leadsheets: datasets to draw a batch from (ignored if batch given).
        filename: path prefix for all output files.
        with_vis: if True, also dump sources/chosen/probs/info arrays.
        batch: optional pre-built ((chords, melody), sample_sources) tuple.
    """
    if batch is None:
        batch = get_batch(leadsheets, True)
    (chords, melody), sample_sources = batch
    generated_out, chosen, vis_probs, vis_info = model.produce(chords, melody)
    if with_vis:
        with open("{}_sources.txt".format(filename), "w") as f:
            f.write('\n'.join(sample_sources))
        np.save('{}_chosen.npy'.format(filename), chosen)
        np.save('{}_probs.npy'.format(filename), vis_probs)
        for i, v in enumerate(vis_info):
            np.save('{}_info_{}.npy'.format(filename, i), v)
    # Pair each generated melody with its (original) input chords.
    # Loop variables are renamed so they no longer shadow the batch-level
    # `melody`/`chords` bindings above (the original rebound both).
    for samplenum, (gen_melody, sample_chords) in enumerate(zip(generated_out, chords)):
        leadsheet.write_leadsheet(sample_chords, gen_melody, '{}_{}.ls'.format(filename, samplenum))
def validate(model, validation_leadsheets):
    """Evaluate the model on VALIDATION_CT batches and return averaged results.

    Args:
        model: object exposing eval(chords, melodies) -> (loss, infos dict).
        validation_leadsheets: datasets passed through to get_batch.

    Returns:
        (mean_loss, mean_infos): loss averaged over the batches, and a dict
        with each info value averaged likewise.

    Fixes over the original: `accum_info` was an undefined name (NameError
    on accumulation and at return), and the per-key accumulation/division
    mutated `accum_loss[k]` instead of `accum_infos[k]`.
    """
    accum_loss = None
    accum_infos = None
    for i in range(VALIDATION_CT):
        loss, infos = model.eval(*get_batch(validation_leadsheets))
        if accum_loss is None:
            accum_loss = loss
            # Copy so we don't mutate the dict the model handed back.
            accum_infos = dict(infos)
        else:
            accum_loss += loss
            for k in accum_infos.keys():
                accum_infos[k] += infos[k]
    accum_loss /= VALIDATION_CT
    for k in accum_infos.keys():
        accum_infos[k] /= VALIDATION_CT
    return accum_loss, accum_infos
def validate_generate(model, validation_leadsheets, generated_dir):
    """Generate model output for each validation leadsheet, one subdir apiece.

    Each leadsheet becomes a whole-file batch of size 1; results are written
    under generated_dir/<leadsheet basename>/generated*.
    """
    for fname in validation_leadsheets:
        chords, melody = leadsheet.parse_leadsheet(fname)
        single_batch = ([chords], [melody]), [fname]
        base = os.path.splitext(os.path.basename(fname))[0]
        outdir = os.path.join(generated_dir, base)
        # Intentionally errors if the directory already exists,
        # so a rerun cannot silently overwrite earlier output.
        os.makedirs(outdir)
        generate(model, None, os.path.join(outdir, "generated"), with_vis=True, batch=single_batch)
def train(model,leadsheets,num_updates,outputdir,start=0,save_params_interval=5000,validation_leadsheets=None,validation_generate_ct=1,auto_connectome_keys=None):
    """Run the training loop, periodically logging, checkpointing, and sampling.

    Args:
        model: object exposing train(chords, melodies) -> (loss, infos dict),
            plus .params and .produce (via generate).
        leadsheets: training datasets passed through to get_batch.
        num_updates: number of training iterations to run.
        outputdir: directory for data.csv, param pickles, and samples.
        start: iteration offset to resume from (0 = fresh run).
        save_params_interval: checkpoint every this many updates (None = never).
        validation_leadsheets: if given, checkpoints trigger validation
            generation instead of sampling from the training set.
        validation_generate_ct: validation generations per checkpoint.
        auto_connectome_keys: if given, each checkpoint is also converted
            via param_cvt.

    A first Ctrl-C sets a flag so the loop exits at a safe point; a second
    Ctrl-C hits the restored original handler and terminates immediately.
    """
    stopflag = [False]
    def signal_handler(signame, sf):
        stopflag[0] = True
        print("Caught interrupt, waiting until safe. Press again to force terminate")
        # Restore the previous handler so a second interrupt is not deferred.
        signal.signal(signal.SIGINT, old_handler)
    old_handler = signal.signal(signal.SIGINT, signal_handler)
    for i in range(start+1,start+num_updates+1):
        if stopflag[0]:
            break
        loss, infos = model.train(*get_batch(leadsheets))
        with open(os.path.join(outputdir,'data.csv'),'a') as f:
            if i == 1:
                # Fresh run: wipe any stale log and write the header row.
                f.seek(0)
                f.truncate()
                f.write("iter, loss, " + ", ".join(k for k,v in sorted(infos.items())) + "\n")
            f.write("{}, {}, ".format(i,loss) + ", ".join(str(v) for k,v in sorted(infos.items())) + "\n")
        if i % 10 == 0:
            print("update {}: {}, info {}".format(i,loss,pformat(infos)))
        if save_params_interval is not None and i % save_params_interval == 0:
            paramfile = os.path.join(outputdir, 'params{}.p'.format(i))
            # `with` ensures the checkpoint is flushed and closed
            # (the original leaked the file handle from open()).
            with open(paramfile, 'wb') as pf:
                pickle.dump(model.params, pf)
            if auto_connectome_keys is not None:
                param_cvt.main(paramfile, 18, auto_connectome_keys, make_zip=True)
            if validation_leadsheets is None:
                generate(model, leadsheets, os.path.join(outputdir,'sample{}'.format(i)))
            else:
                for gen_num in range(validation_generate_ct):
                    validate_generate(model, validation_leadsheets, os.path.join(outputdir, "validation_{}_sample_{}".format(i,gen_num)))
    if not stopflag[0]:
        # Normal completion: put the original SIGINT handler back.
        signal.signal(signal.SIGINT, old_handler)