
Commit 6470b0e

refactor codes: moved functions to util

1 parent 3decef2 · commit 6470b0e
18 files changed: +97, -186 lines

convert.py (+2, -3)

@@ -17,9 +17,8 @@
 
 import sys
 from preprocessing.vcc2018.feature_reader import Whole_feature_reader
-from preprocessing.normalizer import MinMaxScaler, StandardScaler
-from preprocessing.utils import read_hdf5, read_txt
-from util.wrapper import load, get_default_logdir_output
+from util.normalizer import MinMaxScaler, StandardScaler
+from util.misc import read_hdf5, read_txt, load, get_default_logdir_output
 
 def main():

mcd_calculate.py (+1, -1)

@@ -22,7 +22,7 @@
 
 import sys
 from preprocessing.vcc2018.feature_reader import Whole_feature_reader
-from preprocessing.normalizer import MinMaxScaler
+from util.normalizer import MinMaxScaler
 
 def read_and_synthesize(file_list, arch, MCD, input_feat, output_feat):

model/cdvae-cls-gan-mcc.py (+2, -2)

@@ -1,10 +1,10 @@
+import numpy as np
 import tensorflow as tf
 from tensorflow.contrib import slim
+from util.misc import ValueWindow
 from util.layers import (GaussianKLD, GaussianLogDensity, GaussianSampleLayer,
                          Layernorm, conv2d_nchw_layernorm, lrelu,
                          kl_loss, log_loss, gradient_penalty_loss)
-import numpy as np
-from util.wrapper import ValueWindow
 
 class CDVAECLSGAN(object):
     def __init__(self, arch, normalizers=None):

model/cdvae.py (+2, -2)

@@ -1,10 +1,10 @@
+import numpy as np
 import tensorflow as tf
 from tensorflow.contrib import slim
+from util.misc import ValueWindow
 from util.layers import (GaussianKLD, GaussianLogDensity, GaussianSampleLayer,
                          Layernorm, conv2d_nchw_layernorm, lrelu,
                          kl_loss, log_loss)
-import numpy as np
-from util.wrapper import ValueWindow
 
 class CDVAE(object):
     def __init__(self, arch, normalizers=None):

model/vae.py (+2, -2)

@@ -1,10 +1,10 @@
+import numpy as np
 import tensorflow as tf
 from tensorflow.contrib import slim
+from util.misc import ValueWindow
 from util.layers import (GaussianKLD, GaussianLogDensity, GaussianSampleLayer,
                          Layernorm, conv2d_nchw_layernorm, lrelu,
                          kl_loss, log_loss)
-import numpy as np
-from util.wrapper import ValueWindow
 
 class VAE(object):
     def __init__(self, arch, normalizers=None):

preprocessing/vcc2018/calc_stats.py (+3, -3)

@@ -14,9 +14,9 @@
 from sklearn.preprocessing import StandardScaler
 
 import sys
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from utils import (find_files, read_txt, write_hdf5)
-from vcc2018.feature_reader import Whole_feature_reader
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from util.misc import (find_files, read_txt, write_hdf5)
+from preprocess.vcc2018.feature_reader import Whole_feature_reader
 
 def calc_stats(file_list, feat_param, spk_list, args):
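
The only functional change in this script (and in feature_extract.py below) is the extra os.path.dirname call: the scripts previously appended preprocessing/ to sys.path and now append the repository root, so that util.misc resolves. A minimal sketch of the difference; the /repo prefix is a hypothetical path used purely for illustration:

import os

# Hypothetical location of this script inside the repository.
script = '/repo/preprocessing/vcc2018/calc_stats.py'

old_root = os.path.dirname(os.path.dirname(script))
new_root = os.path.dirname(os.path.dirname(os.path.dirname(script)))

print(old_root)  # /repo/preprocessing -> 'from utils import ...' used to resolve here
print(new_root)  # /repo               -> 'from util.misc import ...' resolves here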

preprocessing/vcc2018/feature_extract.py (+2, -2)

@@ -27,8 +27,8 @@
 from scipy.signal import lfilter
 from sprocket.speech.feature_extractor import FeatureExtractor
 
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from utils import (find_files, read_txt, read_hdf5, write_hdf5)
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+from util.misc import (find_files, read_txt, read_hdf5, write_hdf5)
 
 def energy_norm(feat):
     en = np.sum(feat + 1e-8, axis=1, keepdims=True)

synthesize.py (+4, -4)

@@ -23,11 +23,11 @@
 import multiprocessing as mp
 
 import sys
-from preprocessing.synthesizer import world_synthesis
-from preprocessing.utils import read_hdf5
 from preprocessing.vcc2018.feature_reader import Whole_feature_reader
-from preprocessing.postfilter import fast_MLGV
-from preprocessing.f0transformation import log_linear_transformation
+from util.synthesizer import world_synthesis
+from util.misc import read_hdf5
+from util.postfilter import fast_MLGV
+from util.f0transformation import log_linear_transformation
 
 def read_and_synthesize(file_list, arch, stats, input_feat, output_feat):

train.py (+2, -2)

@@ -19,8 +19,8 @@
 
 import sys
 from preprocessing.vcc2018.feature_reader import Segment_feature_reader
-from preprocessing.normalizer import MinMaxScaler
-from preprocessing.utils import read_hdf5
+from util.normalizer import MinMaxScaler
+from util.misc import read_hdf5
 
 def main():

trainer/base.py (+1, -2)

@@ -2,8 +2,7 @@
 import numpy as np
 import logging, os
 
-from util.wrapper import load
-from util.wrapper import ValueWindow
+from util.misc import load, ValueWindow
 
 class Trainer(object):
     def __init__(self, model, train_data, arch, args, dirs, ckpt):

trainer/cdvae-cls-gan.py (+1, -1)

@@ -4,7 +4,7 @@
 import numpy as np
 import tensorflow as tf
 from trainer.base import Trainer
-from util.wrapper import ValueWindow
+from util.misc import ValueWindow
 import time
 
 class CDVAECLSGANTrainer(Trainer):

trainer/vae.py (+1, -1)

@@ -4,7 +4,7 @@
 import numpy as np
 import tensorflow as tf
 from trainer.base import Trainer
-from util.wrapper import ValueWindow
+from util.misc import ValueWindow
 import time
 
 class VAETrainer(Trainer):
File renamed without changes.

preprocessing/utils.py → util/misc.py (+74, -83)

@@ -1,20 +1,18 @@
-# -*- coding: utf-8 -*-
 
 # Based on 2017 Tomoki Hayashi (Nagoya University)
 # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
-from __future__ import division
-from __future__ import print_function
-
-import fnmatch
+import json
 import os
 import sys
-import threading
+from datetime import datetime
 
+import fnmatch
 import h5py
 import numpy as np
 
-from numpy.matlib import repmat
+import tensorflow as tf
+import logging
 
 def read_hdf5(hdf5_name, hdf5_path):
     """FUNCTION TO READ HDF5 DATASET
@@ -118,80 +116,73 @@ def read_txt(file_list):
         filenames = f.readlines()
     return [filename.replace("\n", "") for filename in filenames]
 
-
-class BackgroundGenerator(threading.Thread):
-    """BACKGROUND GENERATOR
-
-    reference:
-    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
-
-    Args:
-        generator (object): generator instance
-        max_prefetch (int): max number of prefetch
-    """
-
-    def __init__(self, generator, max_prefetch=1):
-        threading.Thread.__init__(self)
-        if sys.version_info.major == 2:
-            from Queue import Queue
+def save(saver, sess, logdir, step):
+    ''' Save a model to logdir/model.ckpt-[step] '''
+    model_name = 'model.ckpt'
+    checkpoint_path = os.path.join(logdir, model_name)
+    print('Storing checkpoint to {} ...'.format(logdir), end="")
+    sys.stdout.flush()
+
+    if not os.path.exists(logdir):
+        os.makedirs(logdir)
+
+    saver.save(sess, checkpoint_path, global_step=step)
+    print(' Done.')
+
+
+def load(saver, sess, logdir, ckpt=None):
+    '''
+    Try to load model form a dir (search for the newest checkpoint)
+    '''
+    if ckpt:
+        ckpt = os.path.join(logdir, ckpt)
+        global_step = int(ckpt.split('/')[-1].split('-')[-1])
+        logging.info(' Global step: {}'.format(global_step))
+        saver.restore(sess, ckpt)
+        return global_step
+    else:
+        ckpt = tf.train.latest_checkpoint(logdir)
+        if ckpt:
+            logging.info(' Checkpoint found: {}'.format(ckpt))
+            global_step = int(ckpt.split('/')[-1].split('-')[-1])
+            logging.info(' Global step: {}'.format(global_step))
+            saver.restore(sess, ckpt)
+            return global_step
         else:
-            from queue import Queue
-        self.queue = Queue(max_prefetch)
-        self.generator = generator
-        self.daemon = True
-        self.start()
-
-    def run(self):
-        for item in self.generator:
-            self.queue.put(item)
-        self.queue.put(None)
-
-    def next(self):
-        next_item = self.queue.get()
-        if next_item is None:
-            raise StopIteration
-        return next_item
-
-    def __next__(self):
-        return self.next()
-
-    def __iter__(self):
-        return self
-
-
-class background(object):
-    """BACKGROUND GENERATOR DECORATOR"""
-
-    def __init__(self, max_prefetch=1):
-        self.max_prefetch = max_prefetch
-
-    def __call__(self, gen):
-        def bg_generator(*args, **kwargs):
-            return BackgroundGenerator(gen(*args, **kwargs))
-        return bg_generator
-
-
-def extend_time(feats, upsampling_factor):
-    """FUNCTION TO EXTEND TIME RESOLUTION
-
-    Args:
-        feats (ndarray): feature vector with the shape (T x D)
-        upsampling_factor (int): upsampling_factor
-
-    Return:
-        (ndarray): extend feats with the shape (upsampling_factor*T x D)
-    """
-    # get number
-    n_frames = feats.shape[0]
-    n_dims = feats.shape[1]
-
-    # extend time
-    feats_extended = np.zeros((n_frames * upsampling_factor, n_dims))
-    for j in range(n_frames):
-        start_idx = j * upsampling_factor
-        end_idx = (j + 1) * upsampling_factor
-        feats_extended[start_idx: end_idx] = repmat(feats[j, :], upsampling_factor, 1)
-
-    return feats_extended
-
-
+            print('No checkpoint found')
+            return None
+
+def get_default_logdir_train(note, logdir_root='logdir'):
+    STARTED_DATESTRING = datetime.now().strftime('%0m%0d-%0H%0M-%0S-%Y')
+    logdir = os.path.join(logdir_root, '{}-{}'.format(STARTED_DATESTRING, note))
+    print('Using default logdir: {}'.format(logdir))
+    return logdir
+
+def get_default_logdir_output(args):
+    STARTED_DATESTRING = datetime.now().strftime('%0m%0d-%0H%0M-%0S-%Y')
+    logdir = os.path.join(args.logdir, STARTED_DATESTRING+'-{}-{}'.format(args.src, args.trg))
+    print('Logdir: {}'.format(logdir))
+    return logdir
+
+class ValueWindow():
+    def __init__(self, window_size=100):
+        self._window_size = window_size
+        self._values = []
+
+    def append(self, x):
+        self._values = self._values[-(self._window_size - 1):] + [x]
+
+    @property
+    def sum(self):
+        return sum(self._values)
+
+    @property
+    def count(self):
+        return len(self._values)
+
+    @property
+    def average(self):
+        return self.sum / max(1, self.count)
+
+    def reset(self):
+        self._values = []
File renamed without changes.
File renamed without changes.
File renamed without changes.

util/wrapper.py (-78)

This file was deleted.
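
For orientation, a minimal usage sketch of the helpers that now live in util/misc.py, assuming TensorFlow 1.x (consistent with the tensorflow.contrib imports in this commit). The dummy variable and the constant loss are placeholders for a real model and training step, not code from this repository:

import tensorflow as tf
from util.misc import ValueWindow, get_default_logdir_train, load, save

loss_window = ValueWindow(window_size=100)      # running average of recent loss values
logdir = get_default_logdir_train(note='demo')  # e.g. logdir/<timestamp>-demo

w = tf.get_variable('w', shape=[1])             # stand-in for real model variables
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = load(saver, sess, logdir) or 0       # resume from the newest checkpoint, if any

    for step in range(step, step + 1000):
        loss = 0.0                              # placeholder for an actual training step
        loss_window.append(loss)
        if step % 100 == 0:
            print('step {}: avg loss {:.4f}'.format(step, loss_window.average))
            save(saver, sess, logdir, step)     # writes logdir/model.ckpt-<step>

Callers such as trainer/base.py and convert.py now import these names from util.misc instead of util.wrapper and preprocessing.utils, which accounts for most of the changes in this commit.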
