Skip to content

Commit 6b5e5e0

Browse files
committed
use loky to support nested parallel loading
1 parent 005eece commit 6b5e5e0

File tree

3 files changed

+56
-15
lines changed

3 files changed

+56
-15
lines changed

corl/wc_data/input_fn.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from base import connect, getSeries, getBatch, ftQueryTpl, k_cols
44
from time import strftime
55
from joblib import Parallel, delayed
6+
from loky import get_reusable_executor
67
import tensorflow as tf
78
import sys
89
import multiprocessing
@@ -13,6 +14,7 @@
1314
feat_cols = []
1415
time_shift = None
1516
max_step = None
17+
_prefetch = None
1618

1719

1820
maxbno_query = (
@@ -27,6 +29,22 @@
2729
" flag LIKE %s) t "
2830
)
2931

32+
_executor = None
33+
34+
35+
def _getExecutor():
    """Return the module-level reusable loky executor, creating it lazily.

    Sized to ``parallel * _prefetch`` workers so each prefetched batch can
    itself load series in parallel; idle workers are reclaimed after 20s.
    Both ``parallel`` and ``_prefetch`` are module globals assigned by
    ``getInputs()`` before the first call — TODO confirm no call path
    reaches here before ``getInputs()`` runs.
    """
    # Only _executor is assigned here; reading the other module globals
    # needs no `global` declaration.
    global _executor
    if _executor is None:
        _executor = get_reusable_executor(
            max_workers=parallel * _prefetch, timeout=20)
    return _executor
42+
43+
44+
def _getSeries(p):
    """Adapter for executor.map: unpack a 9-element argument tuple and
    delegate to getSeries. Defined at module top level so it is picklable
    by the loky worker processes."""
    (uuid, code, klid, rcode, val,
     max_step, time_shift, ftQueryK, ftQueryD) = p
    return getSeries(uuid, code, klid, rcode, val,
                     max_step, time_shift, ftQueryK, ftQueryD)
47+
3048

3149
def _loadTestSet(max_step, ntest):
3250
global parallel, time_shift
@@ -59,18 +77,18 @@ def _loadTestSet(max_step, ntest):
5977
# data = [batch, max_step, feature*time_shift]
6078
# vals = [batch]
6179
# seqlen = [batch]
62-
return np.array(uuids, 'U'), np.array(data,'f'), np.array(vals,'f'), np.array(seqlen, 'i')
80+
return np.array(uuids, 'U'), np.array(data, 'f'), np.array(vals, 'f'), np.array(seqlen, 'i')
6381
except:
6482
print(sys.exc_info()[0])
6583
raise
6684
finally:
6785
cnx.close()
6886

6987

70-
def _loadTrainingData(batch_no):
88+
def _loadTrainingData(flag):
7189
global max_step, parallel, time_shift
7290
print("{} loading training set {}...".format(
73-
strftime("%H:%M:%S"), batch_no))
91+
strftime("%H:%M:%S"), flag))
7492
cnx = connect()
7593
try:
7694
cursor = cnx.cursor(buffered=True)
@@ -82,22 +100,23 @@ def _loadTrainingData(batch_no):
82100
'WHERE '
83101
" flag = %s"
84102
)
85-
flag = 'TRAIN_{}'.format(batch_no)
86103
cursor.execute(query, (flag,))
87104
train_set = cursor.fetchall()
88105
total = cursor.rowcount
89106
cursor.close()
90107
uuids, data, vals, seqlen = [], [], [], []
91108
if total > 0:
92109
qk, qd = _getFtQuery()
93-
r = Parallel(n_jobs=parallel)(delayed(getSeries)(
94-
uuid, code, klid, rcode, val, max_step, time_shift, qk, qd
95-
) for uuid, code, klid, rcode, val in train_set)
110+
#joblib doesn't support nested threading
111+
exc = _getExecutor()
112+
params = [(uuid, code, klid, rcode, val, max_step, time_shift, qk, qd)
113+
for uuid, code, klid, rcode, val in train_set]
114+
r = list(exc.map(_getSeries, params))
96115
uuids, data, vals, seqlen = zip(*r)
97116
# data = [batch, max_step, feature*time_shift]
98117
# vals = [batch]
99118
# seqlen = [batch]
100-
return np.array(uuids,'U'), np.array(data,'f'), np.array(vals,'f'), np.array(seqlen, 'i')
119+
return np.array(uuids, 'U'), np.array(data, 'f'), np.array(vals, 'f'), np.array(seqlen, 'i')
101120
except:
102121
print(sys.exc_info()[0])
103122
raise
@@ -176,19 +195,20 @@ def _getDataSetMeta(flag, start=0):
176195
return max_bno, batch_size
177196

178197

179-
def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_count()):
198+
def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_count(), prefetch=2):
180199
"""Input function for the wcc training dataset.
181200
182201
Returns:
183202
A dictionary containing:
184203
uuids,features,labels,seqlens,train_iter,test_iter
185204
"""
186205
# Create dataset for training
187-
global feat_cols, max_step, time_shift, parallel
206+
global feat_cols, max_step, time_shift, parallel, _prefetch
188207
time_shift = shift
189208
feat_cols = cols
190209
max_step = step
191210
parallel = cores
211+
_prefetch = prefetch
192212
feat_size = len(cols)*2*(shift+1)
193213
print("{} Using parallel level:{}".format(strftime("%H:%M:%S"), parallel))
194214
with tf.variable_scope("build_inputs"):
@@ -202,7 +222,7 @@ def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_co
202222
tf.py_func(_loadTrainingData, [f], [
203223
tf.string, tf.float32, tf.float32, tf.int32])
204224
)
205-
).batch(1).prefetch(2)
225+
).batch(1).prefetch(prefetch)
206226
# Create dataset for testing
207227
max_bno, batch_size = _getDataSetMeta("TEST", 1)
208228
test_dataset = tf.data.Dataset.from_tensor_slices(
@@ -215,7 +235,8 @@ def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_co
215235
types = (tf.string, tf.float32, tf.float32, tf.int32)
216236
shapes = (tf.TensorShape([None]), tf.TensorShape(
217237
[None, step, feat_size]), tf.TensorShape([None]), tf.TensorShape([None]))
218-
iter = tf.data.Iterator.from_string_handle(handle, types, train_dataset.output_shapes)
238+
iter = tf.data.Iterator.from_string_handle(
239+
handle, types, train_dataset.output_shapes)
219240

220241
next_el = iter.get_next()
221242
uuids = tf.squeeze(next_el[0])

corl/wc_test/test4.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@
3939
]
4040

4141
parser = argparse.ArgumentParser()
42-
parser.add_argument('parallel', type=int, nargs='?', help='database operation parallel level',
42+
parser.add_argument('--parallel', type=int, help='database operation parallel level',
4343
default=multiprocessing.cpu_count())
44+
parser.add_argument('--prefetch', type=int, help='dataset prefetch batches',
45+
default=2)
4446
parser.add_argument(
4547
'--restart', help='restart training', action='store_true')
4648
args = parser.parse_args()
@@ -77,7 +79,7 @@ def run():
7779
bno = int(os.path.basename(
7880
ckpt.model_checkpoint_path).split('-')[1])
7981
d = input_fn.getInputs(
80-
bno+1, TIME_SHIFT, k_cols, MAX_STEP, args.parallel)
82+
bno+1, TIME_SHIFT, k_cols, MAX_STEP, args.parallel, args.prefetch)
8183
model.setNodes(d['uuids'], d['features'],
8284
d['labels'], d['seqlens'])
8385
saver = tf.train.Saver()
@@ -91,7 +93,7 @@ def run():
9193

9294
if not restored:
9395
d = input_fn.getInputs(
94-
bno+1, TIME_SHIFT, k_cols, MAX_STEP, args.parallel)
96+
bno+1, TIME_SHIFT, k_cols, MAX_STEP, args.parallel, args.prefetch)
9597
model.setNodes(d['uuids'], d['features'],
9698
d['labels'], d['seqlens'])
9799
saver = tf.train.Saver()

executor/test.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from loky import get_reusable_executor
2+
3+
# Shared reusable process pool for the smoke test below.
_executor = get_reusable_executor(max_workers=2, timeout=20)
5+
6+
7+
def fn(p):
    """Worker for the executor smoke test.

    Unpacks the 3-tuple *p* and returns a derived 3-tuple
    ``(c+1, [[a+1, a+2], [a+3, a+4]], b+1)`` to exercise pickling of
    nested return values across processes.
    """
    a, b, c = p
    print("received:{} {} {}".format(a, b, c))
    nested = [[a + 1, a + 2], [a + 3, a + 4]]
    return c + 1, nested, b + 1
11+
12+
13+
# Fan the parameter tuples out to the pool, then regroup the three
# positions of each returned tuple column-wise for display.
params = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
results = list(_executor.map(fn, params))
ra, rb, rc = zip(*results)
print("a:{}".format(ra))
print("b:{}".format(rb))
print("c:{}".format(rc))

0 commit comments

Comments
 (0)