import os
import sys

import psutil
import ray

import tensorflow as tf

from time import strftime
from mysql.connector.pooling import MySQLConnectionPool

k_cols = ["lr"]

# Module-level configuration, populated by getInputs() before the
# datasets are built.
idxlst = None
feat_cols = []
parallel = None
time_shift = None
max_step = None
_prefetch = None
db_pool_size = None
db_host = None
db_port = None
db_pwd = None
shared_args = None
check_input = False

# MySQL connection pool, created by _init().
cnxpool = None

# Looks up the max batch number recorded in fs_stats for the wcc_trn table.
maxbno_query = ("SELECT "
                "    vmax "
                "FROM "
                "    fs_stats "
                "WHERE "
                "    method = 'standardization' "
                "    AND tab = 'wcc_trn' "
                "    AND fields = %s ")


def _getIndex():
    '''
    Returns the set of index codes from the idxlst table, cached after the
    first load.
    '''
    global idxlst
    if idxlst is not None:
        return idxlst
    print("{} loading index...".format(strftime("%H:%M:%S")))
    cnx = cnxpool.get_connection()
    try:
        cursor = cnx.cursor(buffered=True)
        # Case-sensitive collation, so codes differing only in case are kept.
        query = 'SELECT DISTINCT code COLLATE utf8mb4_0900_as_cs FROM idxlst'
        cursor.execute(query)
        rows = cursor.fetchall()
        cursor.close()
        idxlst = {c[0] for c in rows}
        return idxlst
    except Exception:
        print(sys.exc_info()[0])
        raise
    finally:
        cnx.close()

def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
    global cnxpool
    print("{} [PID={}]: initializing mysql connection pool...".format(
        strftime("%H:%M:%S"), os.getpid()))
    cnxpool = MySQLConnectionPool(
        pool_name="dbpool",
        pool_size=db_pool_size or 5,
        host=db_host or '127.0.0.1',
        port=db_port or 3306,
        user='mysql',
        database='secu',
        password=db_pwd or '123456',
        # ssl_ca='',
        # use_pure=True,
        connect_timeout=90000)
    # Note: webui_host, memory and driver_object_store_memory follow the
    # pre-1.0 ray.init() API.
    ray.init(
        num_cpus=psutil.cpu_count(logical=False),
        webui_host='127.0.0.1',
        memory=4 * 1024 * 1024 * 1024,  # 4G
        object_store_memory=4 * 1024 * 1024 * 1024,  # 4G
        driver_object_store_memory=256 * 1024 * 1024  # 256M
    )

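# A quick connectivity probe (illustrative sketch only, not part of the
# original module); it assumes _init() has already been called:
#
#     cnx = cnxpool.get_connection()
#     try:
#         cur = cnx.cursor()
#         cur.execute('SELECT 1')
#         assert cur.fetchone()[0] == 1
#         cur.close()
#     finally:
#         cnx.close()  # returns the connection to the pool
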
def _getDataSetMeta(flag):
    global cnxpool
    cnx = cnxpool.get_connection()
    max_bno, batch_size = None, None
    try:
        print('{} querying max batch no for {} set...'.format(
            strftime("%H:%M:%S"), flag))
        cursor = cnx.cursor()
        cursor.execute(maxbno_query, (flag + "_BNO", ))
        row = cursor.fetchone()
        max_bno = int(row[0])
        print('{} max batch no: {}'.format(strftime("%H:%M:%S"), max_bno))
        query = ("SELECT "
                 "    COUNT(*) "
                 "FROM "
                 "    wcc_trn "
                 "WHERE "
                 "    flag = %s "
                 "    AND bno = 1 ")
        cursor.execute(query, (flag, ))
        row = cursor.fetchone()
        batch_size = row[0]
        print('{} batch size: {}'.format(strftime("%H:%M:%S"), batch_size))
        if batch_size == 0:
            print('{} no more data for {}.'.format(strftime("%H:%M:%S"),
                                                   flag.lower()))
            cursor.close()
            return None, None
        cursor.close()
    except Exception:
        print(sys.exc_info()[0])
        raise
    finally:
        cnx.close()
    return max_bno, batch_size

def getInputs(start_bno=0,
              shift=0,
              cols=None,
              step=30,
              cores=psutil.cpu_count(logical=False),
              pfetch=2,
              pool=None,
              host=None,
              port=None,
              pwd=None,
              vset=None,
              check=False):
    """Input function for the stock trend prediction dataset.

    Returns:
        A dictionary of the following:
        'train': dataset for training
        'test': dataset for test/validation
        'train_batches': total number of batches in the training set
        'test_batches': total number of batches in the test set
        'train_batch_size': size of a single training batch
        'test_batch_size': size of a single test batch
    """
    # Create dataset for training
    global feat_cols, max_step, time_shift
    global parallel, _prefetch, db_pool_size
    global db_host, db_port, db_pwd, shared_args, check_input
    time_shift = shift
    feat_cols = cols or k_cols
    max_step = step
    # Feature vector width: 2 values per column for each of the
    # (time_shift + 1) time offsets.
    feat_size = len(feat_cols) * 2 * (time_shift + 1)
    parallel = cores
    _prefetch = pfetch
    db_pool_size = pool
    db_host = host
    db_port = port
    db_pwd = pwd
    check_input = check
    print("{} Using parallel: {}, prefetch: {}, db_host: {}, port: {}".format(
        strftime("%H:%M:%S"), parallel, _prefetch, db_host, db_port))
    _init(db_pool_size, db_host, db_port, db_pwd)
    # _getFtQuery, _loadTrainingData_v2 and _loadTestSet_v2 are expected to
    # be defined elsewhere.
    qk, qd, qd_idx, qk2 = _getFtQuery()
    # Share read-only arguments with ray workers via the object store.
    shared_args = ray.put({
        'max_step': max_step,
        'time_shift': time_shift,
        'qk': qk,
        'qk2': qk2,
        'qd': qd,
        'qd_idx': qd_idx,
        'index_list': _getIndex(),
        'db_host': db_host,
        'db_port': db_port,
        'db_pwd': db_pwd
    })
    # Query the max batch number from wcc_trn and enumerate the batch
    # numbers between start_bno and that maximum.
    train_batches, train_batch_size = _getDataSetMeta("TR")
    if train_batches is None:
        return None
    bnums = list(range(start_bno, train_batches + 1))

    def mapfunc(bno):
        # tf.numpy_function erases static shape information, so the shapes
        # are restored manually below.
        feat, corl = tf.numpy_function(func=_loadTrainingData_v2,
                                       inp=[bno],
                                       Tout=[tf.float32, tf.float32])
        feat.set_shape((None, max_step, feat_size))
        corl.set_shape((None, 1))
        return feat, corl

    ds_train = tf.data.Dataset.from_tensor_slices(bnums).map(
        mapfunc,
        # num_parallel_calls=tf.data.experimental.AUTOTUNE
        num_parallel_calls=parallel
    ).prefetch(
        # tf.data.experimental.AUTOTUNE
        _prefetch
    )

    # Create dataset for testing. The test set is loaded eagerly, cached,
    # and repeated so it can serve as validation data on every epoch.
    test_batches, test_batch_size = _getDataSetMeta("TS")
    if test_batches is None:
        return None
    ds_test = tf.data.Dataset.from_tensor_slices(
        _loadTestSet_v2(step, test_batches + 1,
                        vset)).batch(test_batch_size).cache().repeat()

    return {
        'train': ds_train,
        'test': ds_test,
        'train_batches': train_batches,
        'test_batches': test_batches,
        'train_batch_size': train_batch_size,
        'test_batch_size': test_batch_size
    }
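

if __name__ == '__main__':
    # A minimal smoke test (illustrative sketch, not part of the original
    # module): build the datasets with default settings and print their
    # metadata. It assumes a reachable MySQL instance with the expected
    # schema; the shift/step values here are arbitrary.
    data = getInputs(shift=4, step=30)
    if data is None:
        print('no training data available')
    else:
        print('train batches: {}, batch size: {}'.format(
            data['train_batches'], data['train_batch_size']))
        print('test batches: {}, batch size: {}'.format(
            data['test_batches'], data['test_batch_size']))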