@@ -3,6 +3,7 @@
from base import connect, getSeries, getBatch, ftQueryTpl, k_cols
from time import strftime
from joblib import Parallel, delayed
+from loky import get_reusable_executor
import tensorflow as tf
import sys
import multiprocessing
@@ -13,6 +14,7 @@
feat_cols = []
time_shift = None
max_step = None
+_prefetch = None


maxbno_query = (
@@ -27,6 +29,22 @@
    " flag LIKE %s) t "
)

+_executor = None
+
+
+def _getExecutor():
+    global parallel, _executor, _prefetch
+    if _executor is not None:
+        return _executor
+    _executor = get_reusable_executor(
+        max_workers=parallel * _prefetch, timeout=20)
+    return _executor
+
+
+def _getSeries(p):
+    uuid, code, klid, rcode, val, max_step, time_shift, ftQueryK, ftQueryD = p
+    return getSeries(uuid, code, klid, rcode, val, max_step, time_shift, ftQueryK, ftQueryD)
+

def _loadTestSet(max_step, ntest):
    global parallel, time_shift
@@ -59,18 +77,18 @@ def _loadTestSet(max_step, ntest):
        # data = [batch, max_step, feature*time_shift]
        # vals = [batch]
        # seqlen = [batch]
-        return np.array(uuids, 'U'), np.array(data,'f'), np.array(vals,'f'), np.array(seqlen, 'i')
+        return np.array(uuids, 'U'), np.array(data, 'f'), np.array(vals, 'f'), np.array(seqlen, 'i')
    except:
        print(sys.exc_info()[0])
        raise
    finally:
        cnx.close()


-def _loadTrainingData(batch_no):
+def _loadTrainingData(flag):
    global max_step, parallel, time_shift
    print("{} loading training set {}...".format(
-        strftime("%H:%M:%S"), batch_no))
+        strftime("%H:%M:%S"), flag))
    cnx = connect()
    try:
        cursor = cnx.cursor(buffered=True)
@@ -82,22 +100,23 @@ def _loadTrainingData(batch_no):
            'WHERE '
            " flag = %s"
        )
-        flag = 'TRAIN_{}'.format(batch_no)
        cursor.execute(query, (flag,))
        train_set = cursor.fetchall()
        total = cursor.rowcount
        cursor.close()
        uuids, data, vals, seqlen = [], [], [], []
        if total > 0:
            qk, qd = _getFtQuery()
-            r = Parallel(n_jobs=parallel)(delayed(getSeries)(
-                uuid, code, klid, rcode, val, max_step, time_shift, qk, qd
-            ) for uuid, code, klid, rcode, val in train_set)
+            #joblib doesn't support nested threading
+            exc = _getExecutor()
+            params = [(uuid, code, klid, rcode, val, max_step, time_shift, qk, qd)
+                      for uuid, code, klid, rcode, val in train_set]
+            r = list(exc.map(_getSeries, params))
            uuids, data, vals, seqlen = zip(*r)
        # data = [batch, max_step, feature*time_shift]
        # vals = [batch]
        # seqlen = [batch]
-        return np.array(uuids,'U'), np.array(data,'f'), np.array(vals,'f'), np.array(seqlen, 'i')
+        return np.array(uuids, 'U'), np.array(data, 'f'), np.array(vals, 'f'), np.array(seqlen, 'i')
    except:
        print(sys.exc_info()[0])
        raise
@@ -176,19 +195,20 @@ def _getDataSetMeta(flag, start=0):
    return max_bno, batch_size


-def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_count()):
+def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_count(), prefetch=2):
    """Input function for the wcc training dataset.

    Returns:
        A dictionary containing:
        uuids,features,labels,seqlens,train_iter,test_iter
    """
    # Create dataset for training
-    global feat_cols, max_step, time_shift, parallel
+    global feat_cols, max_step, time_shift, parallel, _prefetch
    time_shift = shift
    feat_cols = cols
    max_step = step
    parallel = cores
+    _prefetch = prefetch
    feat_size = len(cols)*2*(shift+1)
    print("{} Using parallel level:{}".format(strftime("%H:%M:%S"), parallel))
    with tf.variable_scope("build_inputs"):
@@ -202,7 +222,7 @@ def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_co
                tf.py_func(_loadTrainingData, [f], [
                    tf.string, tf.float32, tf.float32, tf.int32])
            )
-        ).batch(1).prefetch(2)
+        ).batch(1).prefetch(prefetch)
        # Create dataset for testing
        max_bno, batch_size = _getDataSetMeta("TEST", 1)
        test_dataset = tf.data.Dataset.from_tensor_slices(
@@ -215,7 +235,8 @@ def getInputs(start=0, shift=0, cols=None, step=30, cores=multiprocessing.cpu_co
        types = (tf.string, tf.float32, tf.float32, tf.int32)
        shapes = (tf.TensorShape([None]), tf.TensorShape(
            [None, step, feat_size]), tf.TensorShape([None]), tf.TensorShape([None]))
-        iter = tf.data.Iterator.from_string_handle(handle, types, train_dataset.output_shapes)
+        iter = tf.data.Iterator.from_string_handle(
+            handle, types, train_dataset.output_shapes)

        next_el = iter.get_next()
        uuids = tf.squeeze(next_el[0])
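
For reference, the core of this change replaces joblib's `Parallel` with loky's reusable process pool, so the same workers survive across repeated `tf.py_func` calls (the in-diff comment notes joblib does not handle the nested threading this triggers). Below is a minimal, self-contained sketch of that pattern; it assumes loky is installed, and `load_one` and `rows` are hypothetical stand-ins for `getSeries` and the fetched DB rows, not part of this repository.

```python
# Minimal sketch of the reusable-executor pattern used above.
from loky import get_reusable_executor

_executor = None


def _get_executor(workers=4):
    # Reuse a single process pool across calls instead of spawning one per batch.
    global _executor
    if _executor is None:
        _executor = get_reusable_executor(max_workers=workers, timeout=20)
    return _executor


def load_one(p):
    # Workers receive one picklable tuple, mirroring _getSeries(p) above.
    code, klid = p
    return code, klid * 2


if __name__ == "__main__":
    rows = [("000001", i) for i in range(8)]
    results = list(_get_executor().map(load_one, rows))
    print(results)
```

Packing each row into a single tuple keeps the mapped function a plain top-level callable, which is what makes it picklable for the process pool.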
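The train/test switching in the last hunk relies on the TF 1.x feedable-iterator pattern (`tf.data.Iterator.from_string_handle`): one `get_next()` op serves whichever dataset's handle is fed at run time. A rough usage sketch follows, assuming a TF 1.x runtime; the toy datasets are illustrative and not this module's data.

```python
# Rough TF 1.x sketch of the feedable-iterator pattern (toy datasets).
import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(2)
test_ds = tf.data.Dataset.from_tensor_slices(tf.range(100, 106)).batch(2)

# Placeholder selects which dataset feeds the shared get_next() op.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_el = iterator.get_next()

train_iter = train_ds.make_one_shot_iterator()
test_iter = test_ds.make_initializable_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())
    sess.run(test_iter.initializer)
    print(sess.run(next_el, feed_dict={handle: train_handle}))  # batch from train_ds
    print(sess.run(next_el, feed_dict={handle: test_handle}))   # batch from test_ds
```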