
Commit 70d8ebe

remove smart_cond

1 parent ae1075b commit 70d8ebe

File tree

5 files changed: +233 -15 lines

common/common.py (+3)

@@ -131,6 +131,9 @@ def parseArgs():
     return parser.parse_args()
 
 
+def next_power_of_2(x):
+    return 1 if x == 0 else 2**(x - 1).bit_length()
+
 def setupPath():
     p1 = os.path.dirname(os.path.abspath(__file__))
     p2 = os.path.dirname(p1)
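For reference (not part of the commit), the new helper rounds a non-negative integer up to the nearest power of two, mapping 0 to 1:

next_power_of_2(0) == 1   # special-cased: 0 has no power-of-two ceiling
next_power_of_2(5) == 8   # (5 - 1).bit_length() == 3, and 2**3 == 8
next_power_of_2(8) == 8   # exact powers of two map to themselves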

corl/model/tf2/common.py (+3 -3)

@@ -64,7 +64,7 @@ def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): # pylint: di
             x = inputs * kept_idx + alpha_p * (1 - kept_idx)
             # Do affine transformation
             return a * x + b
-        return tf_utils.smart_cond(
+        return tf.cond(
             tf.math.logical_and(
                 tf.math.greater(self.rate, 0.),
                 tf.math.less(self.rate, 1.)
@@ -168,15 +168,15 @@ def call(self, inputs, training=None):
 
         def dropout():
             self.global_step.assign_add(1)
-            rate = tf_utils.smart_cond(
+            rate = tf.cond(
                 tf.math.less(self.global_step, self._decay_start),
                 lambda: self.initial_dropout_rate,
                 lambda: self.cosine_decay_restarts(self.global_step-self._decay_start+1)
             )
             self.dropout_layer.rate = rate
             return self.dropout_layer(inputs, training)
 
-        output = tf_utils.smart_cond(
+        output = tf.cond(
             training,
             dropout,
             lambda: tf.identity(inputs)
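A note on the substitution (my reading, not stated in the commit): tf_utils.smart_cond short-circuits at trace time when its predicate is a static Python boolean, while tf.cond always stages both branches and selects one at run time, so the predicate (e.g. `training`) must be a boolean scalar tensor rather than a plain Python value or None. A minimal standalone sketch of the replacement pattern, using tf.nn.dropout as a stand-in for the layer logic above:

import tensorflow as tf

x = tf.random.uniform((4, 8))
rate = tf.constant(0.3)

# Both branches are traced; the scalar boolean predicate picks one at run time.
y = tf.cond(
    tf.math.logical_and(tf.math.greater(rate, 0.), tf.math.less(rate, 1.)),
    lambda: tf.nn.dropout(x, rate=rate),  # apply dropout when 0 < rate < 1
    lambda: tf.identity(x))               # otherwise pass inputs through unchanged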

model/common.py (+3 -3)

@@ -64,7 +64,7 @@ def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): # pylint: di
             x = inputs * kept_idx + alpha_p * (1 - kept_idx)
             # Do affine transformation
             return a * x + b
-        return tf_utils.smart_cond(
+        return tf.cond(
             tf.math.logical_and(
                 tf.math.greater(self.rate, 0.),
                 tf.math.less(self.rate, 1.)
@@ -168,15 +168,15 @@ def call(self, inputs, training=None):
 
         def dropout():
             self.global_step.assign_add(1)
-            rate = tf_utils.smart_cond(
+            rate = tf.cond(
                 tf.math.less(self.global_step, self._decay_start),
                 lambda: self.initial_dropout_rate,
                 lambda: self.cosine_decay_restarts(self.global_step-self._decay_start+1)
             )
             self.dropout_layer.rate = rate
             return self.dropout_layer(inputs, training)
 
-        output = tf_utils.smart_cond(
+        output = tf.cond(
             training,
             dropout,
             lambda: tf.identity(inputs)

pstk/data/data15.py (new file, +213 lines)

import os
import sys
import psutil
import ray

import tensorflow as tf

from time import strftime
from mysql.connector.pooling import MySQLConnectionPool

k_cols = ["lr"]

idxlst = None
feat_cols = []
parallel = None
time_shift = None
max_step = None
_prefetch = None
db_pool_size = None
db_host = None
db_port = None
db_pwd = None
shared_args = None
check_input = False

cnxpool = None

maxbno_query = ("SELECT "
                "    vmax "
                "FROM "
                "    fs_stats "
                "WHERE "
                "    method = 'standardization' "
                "    AND tab = 'wcc_trn' "
                "    AND fields = %s ")


def _getIndex():
    '''
    Returns a set of index codes from idxlst table.
    '''
    global idxlst
    if idxlst is not None:
        return idxlst
    print("{} loading index...".format(strftime("%H:%M:%S")))
    cnx = cnxpool.get_connection()
    try:
        cursor = cnx.cursor(buffered=True)
        query = ('SELECT distinct code COLLATE utf8mb4_0900_as_cs FROM idxlst')
        cursor.execute(query)
        rows = cursor.fetchall()
        cursor.close()
        idxlst = {c[0] for c in rows}
        return idxlst
    except:
        print(sys.exc_info()[0])
        raise
    finally:
        cnx.close()

def _init(db_pool_size=None, db_host=None, db_port=None, db_pwd=None):
    global cnxpool
    print("{} [PID={}]: initializing mysql connection pool...".format(
        strftime("%H:%M:%S"), os.getpid()))
    cnxpool = MySQLConnectionPool(
        pool_name="dbpool",
        pool_size=db_pool_size or 5,
        host=db_host or '127.0.0.1',
        port=db_port or 3306,
        user='mysql',
        database='secu',
        password=db_pwd or '123456',
        # ssl_ca='',
        # use_pure=True,
        connect_timeout=90000)
    ray.init(
        num_cpus=psutil.cpu_count(logical=False),
        webui_host='127.0.0.1',
        memory=4 * 1024 * 1024 * 1024,  # 4G
        object_store_memory=4 * 1024 * 1024 * 1024,  # 4G
        driver_object_store_memory=256 * 1024 * 1024  # 256M
    )


def _getDataSetMeta(flag):
    global cnxpool
    cnx = cnxpool.get_connection()
    max_bno, batch_size = None, None
    try:
        print('{} querying max batch no for {} set...'.format(
            strftime("%H:%M:%S"), flag))
        cursor = cnx.cursor()
        cursor.execute(maxbno_query, (flag + "_BNO", ))
        row = cursor.fetchone()
        max_bno = int(row[0])
        print('{} max batch no: {}'.format(strftime("%H:%M:%S"), max_bno))
        query = ("SELECT "
                 "    COUNT(*) "
                 "FROM "
                 "    wcc_trn "
                 "WHERE "
                 "    flag = %s "
                 "    AND bno = 1 ")
        cursor.execute(query, (flag, ))
        row = cursor.fetchone()
        batch_size = row[0]
        print('{} batch size: {}'.format(strftime("%H:%M:%S"), batch_size))
        if batch_size == 0:
            print('{} no more data for {}.'.format(strftime("%H:%M:%S"),
                                                   flag.lower()))
            return None, None
        cursor.close()
    except:
        print(sys.exc_info()[0])
        raise
    finally:
        cnx.close()
    return max_bno, batch_size

def getInputs(start_bno=0,
              shift=0,
              cols=None,
              step=30,
              cores=psutil.cpu_count(logical=False),
              pfetch=2,
              pool=None,
              host=None,
              port=None,
              pwd=None,
              vset=None,
              check=False):
    """Input function for the stock trend prediction dataset.

    Returns:
        A dictionary of the following:
            'train': dataset for training
            'test': dataset for test/validation
            'train_batches': total batch of train set
            'test_batches': total batch of test set
            'train_batch_size': size of a single train set batch
            'test_batch_size': size of a single test set batch
    """
    # Create dataset for training
    global feat_cols, max_step, time_shift
    global parallel, _prefetch, db_pool_size
    global db_host, db_port, db_pwd, shared_args, check_input
    time_shift = shift
    feat_cols = cols or k_cols
    max_step = step
    feat_size = len(feat_cols) * 2 * (time_shift + 1)
    parallel = cores
    _prefetch = pfetch
    db_pool_size = pool
    db_host = host
    db_port = port
    db_pwd = pwd
    check_input = check
    print("{} Using parallel: {}, prefetch: {} db_host: {} port: {}".format(
        strftime("%H:%M:%S"), parallel, _prefetch, db_host, db_port))
    _init(db_pool_size, db_host, db_port, db_pwd)
    qk, qd, qd_idx, qk2 = _getFtQuery()
    shared_args = ray.put({
        'max_step': max_step,
        'time_shift': time_shift,
        'qk': qk,
        'qk2': qk2,
        'qd': qd,
        'qd_idx': qd_idx,
        'index_list': _getIndex(),
        'db_host': db_host,
        'db_port': db_port,
        'db_pwd': db_pwd
    })
    # query max flag from wcc_trn and fill a slice with flags between start and max
    train_batches, train_batch_size = _getDataSetMeta("TR")
    if train_batches is None:
        return None
    bnums = [bno for bno in range(start_bno, train_batches + 1)]

    def mapfunc(bno):
        ret = tf.numpy_function(func=_loadTrainingData_v2,
                                inp=[bno],
                                Tout=[tf.float32, tf.float32])
        feat, corl = ret
        feat.set_shape((None, max_step, feat_size))
        corl.set_shape((None, 1))
        return feat, corl

    ds_train = tf.data.Dataset.from_tensor_slices(bnums).map(
        lambda bno: tuple(mapfunc(bno)),
        # num_parallel_calls=tf.data.experimental.AUTOTUNE
        num_parallel_calls=parallel
    ).prefetch(
        # tf.data.experimental.AUTOTUNE
        _prefetch
    )

    # Create dataset for testing
    test_batches, test_batch_size = _getDataSetMeta("TS")
    ds_test = tf.data.Dataset.from_tensor_slices(
        _loadTestSet_v2(step, test_batches + 1,
                        vset)).batch(test_batch_size).cache().repeat()

    return {
        'train': ds_train,
        'test': ds_test,
        'train_batches': train_batches,
        'test_batches': test_batches,
        'train_batch_size': train_batch_size,
        'test_batch_size': test_batch_size
    }
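The training pipeline above wraps a Python data loader in tf.numpy_function, which erases static shape information, hence the set_shape calls in mapfunc. A minimal self-contained sketch of that pattern, with a random-data stub standing in for _loadTrainingData_v2 (referenced but not defined in this file as committed):

import numpy as np
import tensorflow as tf

MAX_STEP, FEAT_SIZE = 30, 2  # illustrative values only

def load_batch(bno):
    # Stub loader; the real code queries MySQL for batch `bno`.
    feat = np.random.rand(16, MAX_STEP, FEAT_SIZE).astype(np.float32)
    corl = np.random.rand(16, 1).astype(np.float32)
    return feat, corl

def mapfunc(bno):
    feat, corl = tf.numpy_function(func=load_batch, inp=[bno],
                                   Tout=[tf.float32, tf.float32])
    # numpy_function loses static shapes; restore them for downstream layers.
    feat.set_shape((None, MAX_STEP, FEAT_SIZE))
    corl.set_shape((None, 1))
    return feat, corl

ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).map(
    mapfunc, num_parallel_calls=2).prefetch(2)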

test/test24_mdnc.py (+11 -9)

@@ -1,8 +1,7 @@
 
-from corl.wc_test.test_runner import run
-from corl.model.tf2 import dnc_regressor
-from corl.wc_test.common import next_power_of_2
-from time import strftime
+from common.train_runner import run
+from model import dnc_regressor
+from common.common import next_power_of_2
 # Path hack.
 import sys
 import os
@@ -38,6 +37,8 @@
 MEMORY_SIZE = 32
 NUM_READ_HEADS = 8
 
+NUM_CLASSES = 5
+
 VAL_SAVE_FREQ = 500
 STEPS_PER_EPOCH = 500
 
@@ -48,7 +49,7 @@
 
 
 def create_regressor():
-    regressor = dnc_regressor.DNC_Model_V8(
+    regressor = dnc_regressor.BaseModel(
         num_cnn_layers=NUM_CNN_LAYERS,
         num_dnc_layers=NUM_DNC_LAYERS,
         num_fcn_layers=NUM_FCN_LAYERS,
@@ -70,6 +71,7 @@ def create_regressor():
         decayed_lr_start=DECAYED_LR_START,
         lr_decay_steps=LR_DECAY_STEPS,
         clipvalue=CLIP_VALUE,
+        num_classes=NUM_CLASSES,
         seed=SEED,
     )
     return regressor
@@ -78,15 +80,15 @@ def create_regressor():
 if __name__ == '__main__':
 
     np.random.seed(SEED)
+
     regressor = create_regressor()
 
-    run(
-        id="test27_mdnc",
+    run(id="stock_trend_test24_mdnc",
         regressor=regressor,
+        vset=None,
         max_step=MAX_STEP,
         time_shift=TIME_SHIFT,
         feat_cols=FEAT_COLS,
         val_save_freq=VAL_SAVE_FREQ,
         steps_per_epoch=STEPS_PER_EPOCH,
-        include_seqlens=INCLUDE_SEQLENS,
-    )
+        data_pipeline=None)

0 commit comments