Skip to content

Commit f0b9a80

Browse files
authored
Add A New Baseline: TCN (#668)
1 parent 5ee2d94 commit f0b9a80

9 files changed

+893
-1
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ Here is a list of models built on `Qlib`.
294294
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
295295
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
296296
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
297+
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](qlib/contrib/model/pytorch_tcn.py)
297298
298299
Your PR of new Quant models is highly welcomed.
299300

examples/benchmarks/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
3434
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
3535
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
3636
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
37+
| TCN | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
3738

3839

3940

@@ -55,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
5556
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
5657
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
5758
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
59+
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
5860

5961
- The selected 20 features are based on the feature importance of a lightgbm-based model.
6062
- The base model of DoubleEnsemble is LGBM.
+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
numpy==1.17.4
2+
pandas==1.1.2
3+
scikit_learn==0.23.2
4+
torch==1.7.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: FilterCol
14+
kwargs:
15+
fields_group: feature
16+
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
17+
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
18+
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
19+
]
20+
- class: RobustZScoreNorm
21+
kwargs:
22+
fields_group: feature
23+
clip_outlier: true
24+
- class: Fillna
25+
kwargs:
26+
fields_group: feature
27+
learn_processors:
28+
- class: DropnaLabel
29+
- class: CSRankNorm
30+
kwargs:
31+
fields_group: label
32+
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
33+
34+
port_analysis_config: &port_analysis_config
35+
strategy:
36+
class: TopkDropoutStrategy
37+
module_path: qlib.contrib.strategy
38+
kwargs:
39+
model: <MODEL>
40+
dataset: <DATASET>
41+
topk: 50
42+
n_drop: 5
43+
backtest:
44+
start_time: 2017-01-01
45+
end_time: 2020-08-01
46+
account: 100000000
47+
benchmark: *benchmark
48+
exchange_kwargs:
49+
limit_threshold: 0.095
50+
deal_price: close
51+
open_cost: 0.0005
52+
close_cost: 0.0015
53+
min_cost: 5
54+
task:
55+
model:
56+
class: TCN
57+
module_path: qlib.contrib.model.pytorch_tcn_ts
58+
kwargs:
59+
d_feat: 20
60+
num_layers: 5
61+
n_chans: 32
62+
kernel_size: 7
63+
dropout: 0.5
64+
n_epochs: 200
65+
lr: 1e-4
66+
early_stop: 20
67+
batch_size: 2000
68+
metric: loss
69+
loss: mse
70+
optimizer: adam
71+
n_jobs: 20
72+
GPU: 0
73+
dataset:
74+
class: TSDatasetH
75+
module_path: qlib.data.dataset
76+
kwargs:
77+
handler:
78+
class: Alpha158
79+
module_path: qlib.contrib.data.handler
80+
kwargs: *data_handler_config
81+
segments:
82+
train: [2008-01-01, 2014-12-31]
83+
valid: [2015-01-01, 2016-12-31]
84+
test: [2017-01-01, 2020-08-01]
85+
step_len: 20
86+
record:
87+
- class: SignalRecord
88+
module_path: qlib.workflow.record_temp
89+
kwargs:
90+
model: <MODEL>
91+
dataset: <DATASET>
92+
- class: SigAnaRecord
93+
module_path: qlib.workflow.record_temp
94+
kwargs:
95+
ana_long_short: False
96+
ann_scaler: 252
97+
- class: PortAnaRecord
98+
module_path: qlib.workflow.record_temp
99+
kwargs:
100+
config: *port_analysis_config
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: RobustZScoreNorm
14+
kwargs:
15+
fields_group: feature
16+
clip_outlier: true
17+
- class: Fillna
18+
kwargs:
19+
fields_group: feature
20+
learn_processors:
21+
- class: DropnaLabel
22+
- class: CSRankNorm
23+
kwargs:
24+
fields_group: label
25+
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
26+
port_analysis_config: &port_analysis_config
27+
strategy:
28+
class: TopkDropoutStrategy
29+
module_path: qlib.contrib.strategy
30+
kwargs:
31+
model: <MODEL>
32+
dataset: <DATASET>
33+
topk: 50
34+
n_drop: 5
35+
backtest:
36+
start_time: 2017-01-01
37+
end_time: 2020-08-01
38+
account: 100000000
39+
benchmark: *benchmark
40+
exchange_kwargs:
41+
limit_threshold: 0.095
42+
deal_price: close
43+
open_cost: 0.0005
44+
close_cost: 0.0015
45+
min_cost: 5
46+
task:
47+
model:
48+
class: TCN
49+
module_path: qlib.contrib.model.pytorch_tcn
50+
kwargs:
51+
d_feat: 6
52+
num_layers: 5
53+
n_chans: 128
54+
kernel_size: 3
55+
dropout: 0.5
56+
n_epochs: 200
57+
lr: 1e-3
58+
early_stop: 20
59+
batch_size: 2000
60+
metric: loss
61+
loss: mse
62+
optimizer: adam
63+
GPU: 0
64+
dataset:
65+
class: DatasetH
66+
module_path: qlib.data.dataset
67+
kwargs:
68+
handler:
69+
class: Alpha360
70+
module_path: qlib.contrib.data.handler
71+
kwargs: *data_handler_config
72+
segments:
73+
train: [2008-01-01, 2014-12-31]
74+
valid: [2015-01-01, 2016-12-31]
75+
test: [2017-01-01, 2020-08-01]
76+
record:
77+
- class: SignalRecord
78+
module_path: qlib.workflow.record_temp
79+
kwargs:
80+
model: <MODEL>
81+
dataset: <DATASET>
82+
- class: SigAnaRecord
83+
module_path: qlib.workflow.record_temp
84+
kwargs:
85+
ana_long_short: False
86+
ann_scaler: 252
87+
- class: PortAnaRecord
88+
module_path: qlib.workflow.record_temp
89+
kwargs:
90+
config: *port_analysis_config

qlib/contrib/model/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030
from .pytorch_nn import DNNModelPytorch
3131
from .pytorch_tabnet import TabnetModel
3232
from .pytorch_sfm import SFM_Model
33+
from .pytorch_tcn import TCN
3334

34-
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
35+
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN)
3536
except ModuleNotFoundError:
3637
pytorch_classes = ()
3738
print("Please install necessary libs for PyTorch models.")

0 commit comments

Comments
 (0)