@@ -12,7 +12,7 @@
from mysql.connector.pooling import MySQLConnectionPool
from corl.wc_data.series import DataLoader, getSeries_v2
from corl.wc_test.test27_mdnc import create_regressor
- from corl.wc_data.input_fn import getWorkloadForPrediction
+ from corl.wc_data.input_fn import getWorkloadForPrediction, getWorkloadForPredictionFromTags

REGRESSOR = create_regressor()
cnxpool = None
@@ -36,6 +36,18 @@
    %s,%s,%s)
"""

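+ # Append-style upsert for kline tags: on a duplicate key the new tag is
+ # concatenated onto the semicolon-separated `tags` column instead of replacing it.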
+ KLINE_TAG_UPSERT = """
+     INSERT INTO `secu`.`kline_d_b_lr_tags`
+     (`code`,`date`,`klid`,`tags`,`udate`,`utime`)
+     VALUES
+     (%s,%s,%s,%s,%s,%s)
+     ON DUPLICATE KEY
+     UPDATE
+         `tags`=concat(`tags`,';',VALUES(`tags`)),
+         `udate`=VALUES(`udate`),
+         `utime`=VALUES(`utime`)
+ """
+

def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
    # FIXME too many db initialization messages in the log and 'aborted clients' in mysql dashboard
@@ -54,13 +66,6 @@ def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
        # ssl_ca='',
        # use_pure=True,
        connect_timeout=90000)
-     # ray.init(
-     #     num_cpus=psutil.cpu_count(logical=False),
-     #     webui_host='127.0.0.1',  # TODO need a different port?
-     #     memory=2 * 1024 * 1024 * 1024,  # 2G
-     #     object_store_memory=512 * 1024 * 1024,  # 512M
-     #     driver_object_store_memory=256 * 1024 * 1024  # 256M
-     # )


def _get_rcodes_for(code, table, dates):
@@ -206,6 +211,8 @@ def _save_prediction(code=None, klid=None, date=None, rcodes=None, top_k=None, p
    try:
        cursor = cnx.cursor()
        cursor.executemany(WCC_INSERT, bucket)
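+         # Mirror each saved row into the tags table: (code, date, klid) from the head
+         # of each bucket tuple, 'wcc_predict' as the tag, (udate, utime) from the tail.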
+         cursor.executemany(KLINE_TAG_UPSERT,
+                            [(t[0], t[1], t[2], 'wcc_predict', t[-2], t[-1]) for t in bucket])
        cnx.commit()
    except:
        print(sys.exc_info()[0])
@@ -352,15 +359,11 @@ def _predict(model_path, max_batch_size, data_queue, infer_queue, args):
    def predict():
        try:
            next_work = data_queue.get()
+
            if isinstance(next_work, str) and next_work == 'done':
-                 if data_queue.empty():
-                     infer_queue.put('done')
-                     return True
-                 else:
-                     print("{} warning, data_queue is still not empty when 'done' signal is received. qsize: {}".format(
-                         strftime("%H:%M:%S"), data_queue.size()))
-                     data_queue.put_nowait('done')
-                     return False
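+                 # Forward the sentinel and stop this consumer; the saver now counts one
+                 # 'done' per parallel worker, so the draining check is no longer needed.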
+                 infer_queue.put('done')
+                 return True
+
            batch = next_work['batch']
            p = model.predict(batch, batch_size=max_batch_size)
            p = np.squeeze(p)
@@ -399,7 +402,7 @@ def predict():
        if c == 2000:
            print('{} predict average: {}'.format(
                strftime("%H:%M:%S"), elapsed / 1000), file=sys.stderr)
-         predict()
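+         # Capture the return value so the enclosing loop can exit once 'done' is seen.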
+         done = predict()
        c += 1

    return done
@@ -411,44 +414,42 @@ def _save_infer_result(top_k, shared_args, infer_queue):
    db_host = shared_args['db_host']
    db_port = shared_args['db_port']
    db_pwd = shared_args['db_pwd']
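+     # Number of upstream workers; used below to count their 'done' sentinels.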
+     parallel = shared_args['args'].parallel
+
    if cnxpool is None:
        _init(1, db_host, db_port, db_pwd)

    def _inner_work():
        # poll work request from 'infer_queue' for saving inference result and handle persistence
        if infer_queue.empty():
            sleep(5)
-             return False
+             return 0
        try:
            next_result = infer_queue.get()
+
            if isinstance(next_result, str) and next_result == 'done':
-                 if infer_queue.empty():
-                     # flush bucket
-                     _save_prediction()
-                     return True
-                 else:
-                     print("{} warning, infer_queue is still not empty when 'done' signal is received. qsize: {}".format(
-                         strftime("%H:%M:%S"), infer_queue.size()))
-                     infer_queue.put_nowait('done')
-             else:
-                 result = next_result['result']
-                 rcodes = next_result['rcodes']
-                 code = next_result['code']
-                 date = next_result['date']
-                 klid = next_result['klid']
-                 udate = next_result['udate']
-                 utime = next_result['utime']
-                 _save_prediction(code, klid, date, rcodes,
-                                  top_k, result, udate, utime)
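+                 # 'done' sentinel: flush any rows still buffered, then report one finished worker.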
+                 # flush bucket
+                 _save_prediction()
+                 return 1
+
+             result = next_result['result']
+             rcodes = next_result['rcodes']
+             code = next_result['code']
+             date = next_result['date']
+             klid = next_result['klid']
+             udate = next_result['udate']
+             utime = next_result['utime']
+             _save_prediction(code, klid, date, rcodes,
+                              top_k, result, udate, utime)
        except Exception:
            sleep(2)
            pass
-
-         return False

-     done = False
-     while not done:
-         done = _inner_work()
+         return 0
+
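+     # Run until every parallel worker has delivered its 'done' sentinel.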
+     done = 0
+     while done < parallel:
+         done += _inner_work()

    cnxpool._remove_connections()

@@ -477,7 +478,7 @@ def predict_wcc(num_actors, min_rcode, max_batch_size, model_path, top_k, shared
        shared_args) for i in range(num_actors)]
    )

-     work = getWorkloadForPrediction(actor_pool,
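+     # Presumably narrows the workload to rows already recorded in the kline tags
+     # table, rather than scanning the source table directly.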
+     work = getWorkloadForPredictionFromTags(actor_pool,
                                            start_anchor,
                                            stop_anchor,
                                            corl_prior,