training_demo.py
import os
import time
import uuid

import torch

from training.DQN.actor import ActorConfig, Actor
from training.DQN.learner import LearnerConfig, DQNLearner
from training.model_io.output_wrapper import Action11OutputWrapper, Action3OutputWrapper
from training.model_io.featureEngine import FeatureEngineVersion3_Simple, FeatureEngine_single600T_Mod
from training.env.trainingEnv import TrainingStockEnv
from training.model.DNN import DNN, FullPosDNN
from training.replay.ReplayBuffer import ReplayBuffer
from training.reward.rewards import scaled_net_return, single_trade_return600TWAP, single_trade_return600TWAP_Mod
from training.util.logger import logger
from training.util.explicit_control import ExplicitControlConf
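
# Training demo: builds the feature engine, DQN model, replay buffer, trading
# environment, actor, and learner, then runs the acting/learning loop below.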
if __name__ == '__main__':
    TRAINING_EPISODE_NUM = 1e6
    LEARNING_PERIOD = 16
    SAVING_PATH = '/mnt/data3/rl-data/training_res'

    # Unique experiment directory: timestamp plus a short random suffix.
    exp_name = f'{time.strftime("%Y%m%d:%H%M%S", time.localtime())}-{str(uuid.uuid4())[:8]}'
    os.makedirs(os.path.join(SAVING_PATH, exp_name))
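
    # Observation/action pipeline: the feature engine produces model inputs
    # (max position of 10) and the output wrapper maps the DNN's output head
    # to the environment's discrete action set.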
    feature_engine = FeatureEngine_single600T_Mod(max_position=10)
    model = DNN(
        input_dim=feature_engine.get_input_shape(),
        hidden_dim=[64],
        output_dim=Action3OutputWrapper.get_output_shape(),
    )
    model_output_wrapper = Action3OutputWrapper(model)
    replay_buffer = ReplayBuffer(10000)
    explicit_config = ExplicitControlConf(
        signal_risk_thresh=-float('inf'),
    )
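
    # Trading environment: 'ordered' episode mode, single_trade_return600TWAP_Mod
    # reward, metrics saved under the experiment directory.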
    env = TrainingStockEnv(
        mode='ordered',
        reward_fn=single_trade_return600TWAP_Mod,
        save_metric_path=os.path.join(SAVING_PATH, exp_name),
        save_code_metric=True,
        max_postion=feature_engine.max_position)
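
    # Actor: collects transitions into the replay buffer with epsilon-greedy
    # exploration (eps decays from 0.9 to 0.05, presumably over ~1e6 steps as
    # set by eps_decay).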
    actor_config = ActorConfig(
        eps_start=0.9,
        eps_end=0.05,
        eps_decay=1e6,
        minimal_buffer_size=1000,
    )
    actor = Actor(
        env,
        feature_engine,
        model_output_wrapper,
        replay_buffer,
        actor_config,
        explicit_config,
    )
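
    # Learner: DQN updates with soft target-network updates (tau), SGD at
    # lr=1e-5 on CPU; checkpoints are presumably written every model_save_step
    # learner steps into the experiment directory.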
    learner_config = LearnerConfig(
        batch_size=128,
        gamma=0.99,
        tau=0.005,
        lr=1e-5,
        optimizer_type='SGD',
        device=torch.device('cpu'),
        # model_save_prefix=SAVING_PATH,
        model_save_step=20000,
        minimal_buffer_size=1000,
    )
    learner = DQNLearner(
        learner_config,
        model,
        replay_buffer,
        os.path.join(SAVING_PATH, exp_name),
    )
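
    # Main loop: the actor steps the environment each iteration; the learner
    # performs one update every LEARNING_PERIOD environment steps, and progress
    # is logged every 1000 learner updates.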
    while env.episode_cnt < TRAINING_EPISODE_NUM:
        actor.step()
        if env.step_cnt % LEARNING_PERIOD == 0:
            loss = learner.step()
            if env.step_cnt % (1000 * LEARNING_PERIOD) == 0:
                logger.info(f"learner stepping, "
                            f"current step count: {env.step_cnt}, "
                            f"current episode count: {env.episode_cnt}, "
                            f"learning loss: {loss}")