Commit c007319

cleanup todo list and query workload from tag table

1 parent a5e82c4 commit c007319

11 files changed: +143 -70 lines

corl/model/tf2/dnc_regressor.py (-5)

@@ -301,8 +301,6 @@ def getModel(self):
             # name='features',
             dtype=tf.float32)
 
-        #TODO add CNN before RNN?
-
         # create sequence of DNC layers
         layer = inputs
         for i in range(self._num_dnc_layers):
@@ -319,11 +317,8 @@ def getModel(self):
                 return_sequences=True if i+1 < self._num_dnc_layers else False,
                 name='rnn_{}'.format(i),
             )
-            # TODO use separate dnc cell for forward & backward pass?
            layer = keras.layers.Bidirectional(layer=rnn, name='bidir_{}'.format(i))(layer)
 
-        # TODO add batch normalization layer before FCN?
-
         if self._dropout_rate > 0:
             layer = keras.layers.AlphaDropout(self._dropout_rate)(layer)
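For context on the code that survives the cleanup: only the last recurrent layer collapses the sequence (return_sequences=False), every layer is wrapped in Bidirectional, and AlphaDropout (designed for selu-activated, self-normalizing networks) follows the stack. A minimal sketch of the same stacking pattern, with a stock LSTM standing in for the repo's DNC layer and placeholder shapes:

import tensorflow as tf
from tensorflow import keras

num_layers = 2     # plays the role of self._num_dnc_layers
units = 64
dropout_rate = 0.1

inputs = keras.Input(shape=(30, 8), dtype=tf.float32)
layer = inputs
for i in range(num_layers):
    rnn = keras.layers.LSTM(
        units,
        # intermediate layers emit full sequences for the next RNN;
        # the last layer emits only the final output vector
        return_sequences=i + 1 < num_layers,
        name='rnn_{}'.format(i),
    )
    layer = keras.layers.Bidirectional(layer=rnn, name='bidir_{}'.format(i))(layer)

if dropout_rate > 0:
    # AlphaDropout preserves mean and variance; intended for selu nets
    layer = keras.layers.AlphaDropout(dropout_rate)(layer)

model = keras.Model(inputs, layer)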

corl/model/tf2/lstm.py (-1)

@@ -314,7 +314,6 @@ def getModel(self):
         return self.model
 
     def compile(self):
-        # TODO study how to use ReduceLROnPlateau and CosineDecayRestarts on adam optimizer
         # decay = tf.keras.experimental.CosineDecayRestarts(self._lr,
         #                                                   self._lr_decay_steps,
         #                                                   t_mul=1.02,
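The removed TODO asked how ReduceLROnPlateau and CosineDecayRestarts interact with Adam; the short answer is that they are alternatives rather than composable: the callback works by assigning a plain float to optimizer.lr, which breaks once the optimizer holds a schedule object. A minimal sketch of the schedule route, with placeholder hyperparameters standing in for self._lr and self._lr_decay_steps:

import tensorflow as tf

# Sketch only: CosineDecayRestarts fed to Adam as a learning-rate schedule.
# The hyperparameter values are placeholders, not the repo's settings.
lr_schedule = tf.keras.experimental.CosineDecayRestarts(
    initial_learning_rate=1e-3,  # plays the role of self._lr
    first_decay_steps=1000,      # plays the role of self._lr_decay_steps
    t_mul=1.02,                  # each restart cycle grows 2% longer
    m_mul=0.95,                  # each restart peak shrinks 5%
    alpha=1e-4)                  # decay floor, as a fraction of the peak
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)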

corl/wc_data/input_fn.py (+70 -2)

@@ -60,6 +60,8 @@
 
 def _init_db(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
     global cnxpool
+    if cnxpool is not None:
+        return
     print("{} [PID={}]: initializing mysql connection pool...".format(
         strftime("%H:%M:%S"), os.getpid()))
     cnxpool = MySQLConnectionPool(
@@ -624,13 +626,79 @@ def mapfunc(bno):
     }
 
 
+def getWorkloadForPredictionFromTags(actor_pool, start_anchor, stop_anchor, corl_prior, max_step, time_shift, host, port, pwd):
+    ##TODO realize me. query workload from tag table
+    '''
+    Returns list of tuples (code, date, klid)
+    '''
+    global cnxpool
+    _init_db(1, host, port, pwd)
+    qry = (
+        "SELECT "
+        "   partition_name "
+        "FROM "
+        "   information_schema.partitions "
+        "WHERE "
+        "   table_schema = 'secu' "
+        "   AND table_name = 'kline_d_b_lr' "
+    )
+    cond = ''
+    if start_anchor is not None:
+        c1, k1 = start_anchor
+        cond += '''
+            and (
+                t.code > '{}'
+                or (t.code = '{}' and t.klid >= {})
+            )
+        '''.format(c1, c1, k1)
+    if stop_anchor is not None:
+        c2, k2 = stop_anchor
+        cond += '''
+            and (
+                t.code < '{}'
+                or (t.code = '{}' and t.klid < {})
+            )
+        '''.format(c2, c2, k2)
+    cnx = cnxpool.get_connection()
+    cursor = None
+    try:
+        print('{} querying partitions for kline_d_b_lr'.format(strftime("%H:%M:%S")))
+        cursor = cnx.cursor()
+        cursor.execute(qry)
+        rows = cursor.fetchall()
+        total = cursor.rowcount
+        print('{} #partitions: {}'.format(strftime("%H:%M:%S"), total))
+    except:
+        print(sys.exc_info()[0])
+        raise
+    finally:
+        if cursor is not None:
+            cursor.close()
+        cnx.close()
+
+    tasks = actor_pool.map(
+        lambda a, part: a.get_wcc_infer_work_request.remote(part, cond),
+        rows
+    )
+
+    # remove empty sublists
+    workloads = [t for t in list(tasks) if t]
+    # flatten the list and remove empty tuples
+    workloads = [val for sublist in workloads for val in sublist if val]
+    # sort by code and klid in ascending order
+    workloads.sort(key=lambda tup: (tup[0], tup[3]))
+
+    print('{} total workloads: {}'.format(
+        strftime("%H:%M:%S"), len(workloads)))
+
+    return workloads
+
 def getWorkloadForPrediction(actor_pool, start_anchor, stop_anchor, corl_prior, max_step, time_shift, host, port, pwd):
     '''
     Returns list of tuples (code, date, klid)
     '''
     global cnxpool
-    if cnxpool is None:
-        _init_db(1, host, port, pwd)
+    _init_db(1, host, port, pwd)
     qry = (
         "SELECT "
         "   partition_name "

corl/wcc/Dockerfile (+14 -1)

@@ -1,7 +1,20 @@
 FROM tensorflow/tensorflow:2.4.1-gpu
 
-COPY requirements.txt .
+# set timezone to Asia/Shanghai
+#ENV TZ Asia/Shanghai
 
+#RUN echo $TZ > /etc/timezone && \
+#    apt-get update && apt-get install -y tzdata && \
+#    rm /etc/localtime && \
+#    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && \
+#    dpkg-reconfigure -f noninteractive tzdata && \
+#    apt-get clean
+
+# install timedatectl & dependencies, and enable time synchronization
+#RUN apt-get install -y systemd dbus
+#RUN timedatectl set-ntp on
+
+# upgrade pip and install requirements
 RUN python3 -m pip install --upgrade pip
 RUN pip install -r requirements.txt;
 RUN rm -rf requirements.txt

corl/wcc/worker.py (+43 -42)

@@ -12,7 +12,7 @@
 from mysql.connector.pooling import MySQLConnectionPool
 from corl.wc_data.series import DataLoader, getSeries_v2
 from corl.wc_test.test27_mdnc import create_regressor
-from corl.wc_data.input_fn import getWorkloadForPrediction
+from corl.wc_data.input_fn import getWorkloadForPrediction, getWorkloadForPredictionFromTags
 
 REGRESSOR = create_regressor()
 cnxpool = None
@@ -36,6 +36,18 @@
     %s,%s,%s)
 """
 
+KLINE_TAG_UPSERT = """
+    INSERT INTO `secu`.`kline_d_b_lr_tags`
+        (`code`,`date`,`klid`,`tags`,`udate`,`utime`)
+    VALUES
+        (%s,%s,%s,%s,%s,%s)
+    ON DUPLICATE KEY
+    UPDATE
+        `tags`=concat(`tags`,';',VALUES(`tags`)),
+        `udate`=VALUES(`udate`),
+        `utime`=VALUES(`utime`)
+"""
+
 
 def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
     # FIXME too many db initialization message in the log and 'aborted clients' in mysql dashboard
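The upsert appends new tags to an existing row's `tags` column with a ';' separator rather than overwriting it. A minimal sketch of how parameter rows line up with this statement — assuming, as in the _save_prediction hunk below, that each bucket entry starts with (code, date, klid, ...) and ends with (udate, utime):

# Sketch: parameter rows for KLINE_TAG_UPSERT via cursor.executemany().
# The bucket row here is fabricated for illustration; only the tuple
# layout (code, date, klid, ..., udate, utime) is taken from the commit.
bucket = [('000001', '2021-03-01', 4321, 0.87, '2021-03-02', '09:30:00')]
rows = [(t[0], t[1], t[2], 'wcc_predict', t[-2], t[-1]) for t in bucket]
# The first insert creates tags='wcc_predict'; a second run on the same
# (code, date, klid) key yields tags='wcc_predict;wcc_predict' -- concat()
# does not deduplicate, so repeated runs grow the tags string.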
@@ -54,13 +66,6 @@ def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
         # ssl_ca='',
         # use_pure=True,
         connect_timeout=90000)
-    # ray.init(
-    #     num_cpus=psutil.cpu_count(logical=False),
-    #     webui_host='127.0.0.1',  # TODO need a different port?
-    #     memory=2 * 1024 * 1024 * 1024,  # 2G
-    #     object_store_memory=512 * 1024 * 1024,  # 512M
-    #     driver_object_store_memory=256 * 1024 * 1024  # 256M
-    # )
 
 
 def _get_rcodes_for(code, table, dates):
@@ -206,6 +211,8 @@ def _save_prediction(code=None, klid=None, date=None, rcodes=None, top_k=None, p
     try:
         cursor = cnx.cursor()
         cursor.executemany(WCC_INSERT, bucket)
+        cursor.executemany(KLINE_TAG_UPSERT,
+                           [(t[0], t[1], t[2], 'wcc_predict', t[-2], t[-1]) for t in bucket])
         cnx.commit()
     except:
         print(sys.exc_info()[0])
@@ -352,15 +359,11 @@ def _predict(model_path, max_batch_size, data_queue, infer_queue, args):
     def predict():
         try:
             next_work = data_queue.get()
+
             if isinstance(next_work, str) and next_work == 'done':
-                if data_queue.empty():
-                    infer_queue.put('done')
-                    return True
-                else:
-                    print('{} warning, data_queue is still not empty when ''done'' signal is received. qsize: {}'.format(
-                        strftime("%H:%M:%S"), data_queue.size()))
-                    data_queue.put_nowait('done')
-                    return False
+                infer_queue.put('done')
+                return True
+
             batch = next_work['batch']
             p = model.predict(batch, batch_size=max_batch_size)
             p = np.squeeze(p)
@@ -399,7 +402,7 @@ def predict():
         if c == 2000:
             print('{} predict average: {}'.format(
                 strftime("%H:%M:%S"), elapsed/1000), file=sys.stderr)
-        predict()
+        done = predict()
         c += 1
 
     return done
@@ -411,44 +414,42 @@ def _save_infer_result(top_k, shared_args, infer_queue):
     db_host = shared_args['db_host']
     db_port = shared_args['db_port']
     db_pwd = shared_args['db_pwd']
+    parallel = shared_args['args'].parallel
+
     if cnxpool is None:
         _init(1, db_host, db_port, db_pwd)
 
     def _inner_work():
        # poll work request from 'infer_queue' for saving inference result and handle persistence
         if infer_queue.empty():
             sleep(5)
-            return False
+            return 0
         try:
             next_result = infer_queue.get()
+
             if isinstance(next_result, str) and next_result == 'done':
-                if infer_queue.empty():
-                    # flush bucket
-                    _save_prediction()
-                    return True
-                else:
-                    print('{} warning, infer_queue is still not empty when ''done'' signal is received. qsize: {}'.format(
-                        strftime("%H:%M:%S"), infer_queue.size()))
-                    infer_queue.put_nowait('done')
-            else:
-                result = next_result['result']
-                rcodes = next_result['rcodes']
-                code = next_result['code']
-                date = next_result['date']
-                klid = next_result['klid']
-                udate = next_result['udate']
-                utime = next_result['utime']
-                _save_prediction(code, klid, date, rcodes,
-                                 top_k, result, udate, utime)
+                # flush bucket
+                _save_prediction()
+                return 1
+
+            result = next_result['result']
+            rcodes = next_result['rcodes']
+            code = next_result['code']
+            date = next_result['date']
+            klid = next_result['klid']
+            udate = next_result['udate']
+            utime = next_result['utime']
+            _save_prediction(code, klid, date, rcodes,
+                             top_k, result, udate, utime)
         except Exception:
             sleep(2)
             pass
-
-        return False
 
-    done = False
-    while not done:
-        done = _inner_work()
+        return 0
+
+    done = 0
+    while done < parallel:
+        done += _inner_work()
 
     cnxpool._remove_connections()
 
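The shutdown protocol changes here: the old code required the queue to be empty when a 'done' marker arrived and re-queued the marker otherwise, while the new code has each predictor forward exactly one 'done' and lets the consumer count markers until all `parallel` producers have signed off. The hunk above also fixes predict()'s return value being discarded. A minimal, self-contained sketch of this counting pattern, with hypothetical names and threads standing in for the Ray workers:

import queue
import threading

NUM_PRODUCERS = 4  # plays the role of shared_args['args'].parallel

q = queue.Queue()

def producer(pid):
    # each producer enqueues its work items, then exactly one 'done' marker
    for i in range(3):
        q.put((pid, i))
    q.put('done')

def consumer():
    done = 0
    while done < NUM_PRODUCERS:   # stop only after every producer signed off
        item = q.get()
        if item == 'done':
            done += 1             # count markers; no q.empty() races involved
            continue
        print('saving result', item)

threads = [threading.Thread(target=producer, args=(i,)) for i in range(NUM_PRODUCERS)]
for t in threads:
    t.start()
consumer()
for t in threads:
    t.join()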

@@ -477,7 +478,7 @@ def predict_wcc(num_actors, min_rcode, max_batch_size, model_path, top_k, shared
         shared_args) for i in range(num_actors)]
     )
 
-    work = getWorkloadForPrediction(actor_pool,
+    work = getWorkloadForPredictionFromTags(actor_pool,
                                     start_anchor,
                                     stop_anchor,
                                     corl_prior,

pstk/data/data15.py (+14 -13)

@@ -62,19 +62,20 @@ def _getIndex():
 
 def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
     global cnxpool
-    print("{} [PID={}]: initializing mysql connection pool...".format(
-        strftime("%H:%M:%S"), os.getpid()))
-    cnxpool = MySQLConnectionPool(
-        pool_name="dbpool",
-        pool_size=db_pool_size or 5,
-        host=db_host or '127.0.0.1',
-        port=db_port or 3306,
-        user='mysql',
-        database='secu',
-        password=db_pwd or '123456',
-        # ssl_ca='',
-        # use_pure=True,
-        connect_timeout=90000)
+    if cnxpool is None:
+        print("{} [PID={}]: initializing mysql connection pool...".format(
+            strftime("%H:%M:%S"), os.getpid()))
+        cnxpool = MySQLConnectionPool(
+            pool_name="dbpool",
+            pool_size=db_pool_size or 5,
+            host=db_host or '127.0.0.1',
+            port=db_port or 3306,
+            user='mysql',
+            database='secu',
+            password=db_pwd or '123456',
+            # ssl_ca='',
+            # use_pure=True,
+            connect_timeout=90000)
     ray.init(
         num_cpus=psutil.cpu_count(logical=False),
         webui_host='127.0.0.1',
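This is the same idempotence guard added to _init_db in corl/wc_data/input_fn.py above, written as a wrapping if rather than an early return. A condensed, runnable sketch of the behavior, with a plain object standing in for the real MySQLConnectionPool:

cnxpool = None
init_count = 0

def _init():
    # guarded init: only the first call per process creates the pool,
    # so call sites no longer need their own `if cnxpool is None` checks
    global cnxpool, init_count
    if cnxpool is None:
        init_count += 1
        cnxpool = object()  # stand-in for MySQLConnectionPool(...)

_init()
_init()
assert init_count == 1  # the second call was a no-op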

pstk/model/model10.py (-1)

@@ -59,7 +59,6 @@ def logits(self):
     @staticmethod
     def tcn(self, inputs):
         # Temporal Convolutional Network
-        #TODO: implement me
         return None
 
 
pstk/model/model11.py (-1)

@@ -88,7 +88,6 @@ def fcn(self, inputs):
     @staticmethod
     def rnn(self, inputs):
         # Deep Residual RNN
-        # TODO: try MultiRNNCell of MultiRNNCell, wrapped in a residual wrapper
         cells = []
         feat_size = int(inputs.get_shape()[-1])
         # p = int(round(self._rnn_layers ** 0.5))

pstk/model/model3.py

-2
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,9 @@ def prediction(self):
174174
@staticmethod
175175
def rnn(self, input):
176176
# Recurrent network.
177-
# TODO add tf.contrib.rnn.ConvLSTMCell?
178177
step = int(input.get_shape()[1])
179178
feat = int(input.get_shape()[2])
180179
c = feat // self._input_width # channel
181-
# TODO step & width must equal?
182180
input = tf.reshape(input, [-1, step, self._input_width, c])
183181
clc = tf.contrib.rnn.ConvLSTMCell(
184182
conv_ndims=1,

pstk/model/wavenet/model.py (+1 -1)

@@ -337,7 +337,7 @@ def _create_dilation_layer(self, input_batch, layer_index, dilation,
 
     def _generator_conv(self, input_batch, state_batch, weights):
         '''Perform convolution for a single convolutional processing step.'''
-        # TODO generalize to filter_width > 2
+        # TD: generalize to filter_width > 2
         past_weights = weights[0, :, :]
         curr_weights = weights[1, :, :]
         output = tf.matmul(state_batch, past_weights) + tf.matmul(

test.py (+1 -1)

@@ -363,7 +363,7 @@ def testTasklist():
     # print('#talst: {}'.format(len(talst)))
     # for i in range(50):
     #     print_talst_element(i, talst)
-    # TODO test efficient status update
+    # test efficient status update
     # for i in range(50):
     #     delayed_write_talst(i, talst)
     # print("job done")
