forked from tradytics/eiten
-
Notifications
You must be signed in to change notification settings - Fork 7
/
data_loader.py
159 lines (133 loc) · 5.24 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Basic libraries
import os
import collections
import pandas as pd
import yfinance as yf
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
class DataEngine:
    """Load a list of ticker symbols and download their price history.

    ``args`` must expose the attributes this class reads:
    ``stocks_file_path``, ``data_granularity_minutes``, ``history_to_use``,
    ``is_test`` and ``future_bars``.
    """

    def __init__(self, args):
        """Store ``args`` and read the ticker list from disk.

        ``args.stocks_file_path`` is resolved relative to the directory
        that contains this module.
        """
        print("\n--> Data engine has been initialized...")
        self.args = args

        # Stocks list
        self.directory_path = os.path.dirname(os.path.abspath(__file__))
        self.stocks_file_path = os.path.join(
            self.directory_path, self.args.stocks_file_path)
        self.stocks_list = []

        # Load stock names in a list
        self.load_stocks_from_file()

        # Nothing in this class populates this dictionary; it is kept
        # because code outside this class may read the attribute.
        self.data_dictionary = {}

        # Row count of the first successful download. get_data() rejects
        # any later symbol whose download has a different row count.
        self.stock_data_length = 0

    def load_stocks_from_file(self):
        """
        Load stock names from the file.

        Reads one symbol per line from ``self.stocks_file_path``, strips
        surrounding whitespace, drops blank lines and duplicates, and
        stores the sorted result in ``self.stocks_list``.
        """
        print("Loading all stocks from file...")
        with open(self.stocks_file_path, "r", encoding="utf-8") as f:
            symbols = {line.strip() for line in f}

        # A blank (or whitespace-only) line would otherwise become an
        # empty ticker symbol that is later sent to the downloader.
        symbols.discard("")

        stocks_list = sorted(symbols)
        print(f"Total number of stocks: {len(stocks_list)}")
        self.stocks_list = stocks_list

    def get_most_frequent_count(self, input_list):
        """Return the value that occurs most often in ``input_list``.

        Ties are resolved in favour of the value encountered first.
        Raises ``IndexError`` when ``input_list`` is empty.
        """
        counter = collections.Counter(input_list)
        return counter.most_common(1)[0][0]

    def _split_data(self, data):
        """Split the "Adj Close" column into historical and future parts.

        When ``args.is_test`` is truthy, the last ``args.future_bars``
        values are returned as the future part and everything before
        them as the historical part. Otherwise the whole column is
        historical and the future part is ``None``.
        """
        prices = data["Adj Close"].values
        if not self.args.is_test:
            return prices, None

        # An explicit split index avoids the ``[:-0]`` / ``[-0:]`` trap,
        # where zero future bars would yield an empty history and put
        # every value into the future part.
        split_at = max(len(prices) - self.args.future_bars, 0)
        return prices[:split_at], prices[split_at:]

    def _format_symbol(self, s):
        """Normalise a raw ticker string before it is downloaded.

        Upper-cases the symbol, rewrites a ``.VN`` suffix to ``.V`` and,
        when the symbol contains more than one dot, replaces the first
        dot with a dash (e.g. ``brk.b.to`` -> ``BRK-B.TO``).
        """
        x = s.upper()
        x = x.replace(".VN", ".V")
        if len(x.split(".")) > 2:
            x = x.replace(".", "-", 1)
        return x

    def _resolve_period_and_interval(self):
        """Map ``args.data_granularity_minutes`` to (period, interval)."""
        granularity = self.args.data_granularity_minutes
        if granularity == 1:
            # Must be part of the same if/elif chain: as two separate
            # ``if`` statements the final ``else`` overwrote this branch
            # and 1-minute requests asked for 30 days instead of 7.
            return "7d", f"{granularity}m"
        if granularity == 3600:
            # The value 3600 selects daily bars over five years.
            return "5y", "1d"
        return "30d", f"{granularity}m"

    def get_data(self, symbol_raw):
        """
        Get stock data from yahoo finance.

        Returns a ``(historical_prices, future_prices)`` pair as produced
        by ``_split_data``. Returns ``(None, None)`` when the download
        fails, comes back empty, or has a different number of rows than
        the first successful download. Failures are printed, not raised.
        """
        symbol = self._format_symbol(symbol_raw)
        period, interval = self._resolve_period_and_interval()

        try:
            stock_prices = yf.download(
                tickers=symbol,
                period=period,
                interval=interval,
                auto_adjust=False,
                progress=False)

            # An empty frame must not become the reference length, and
            # must not be stored as a zero-length column downstream.
            if stock_prices.empty:
                raise ValueError(f"{symbol}: No data returned")

            if self.stock_data_length == 0:
                self.stock_data_length = stock_prices.shape[0]
            elif stock_prices.shape[0] != self.stock_data_length:
                raise ValueError(f"{symbol}: Invalid Stock Length")

            if self.args.history_to_use == "all":
                # Per the original author's note: yfinance gives some 0
                # values in the first index, so the first row is dropped.
                stock_prices = stock_prices.iloc[1:]
            else:
                # int() accepts both an int and a numeric string.
                # NOTE(review): confirm how the argument parser types
                # history_to_use; negating a str would raise TypeError.
                history_bars = int(self.args.history_to_use)
                stock_prices = stock_prices.iloc[-history_bars:]

            return self._split_data(stock_prices)
        except Exception as e:
            # Deliberate best-effort: one bad symbol must not abort
            # the whole collection run.
            print("Exception", e)
            return None, None

    def collect_data_for_all_tickers(self):
        """
        Iterates over all symbols and collects their data.

        Returns a dict with two DataFrames, ``"historical"`` and
        ``"future"``, holding one column per symbol that was collected
        successfully. Missing values are filled with 1, and both frames
        are also written to ``historical.csv`` and ``future.csv`` in the
        current working directory (write failures are printed only).
        """
        print("Loading data for all stocks...")
        data_dict = {
            "historical": pd.DataFrame(),
            "future": pd.DataFrame(),
        }

        for symbol in tqdm(self.stocks_list):
            try:
                historical_data, future_data = self.get_data(symbol)
                if historical_data is not None:
                    data_dict["historical"][symbol] = historical_data
                if future_data is not None:
                    data_dict["future"][symbol] = future_data
            except Exception as e:
                print("Exception", e)
                continue

        data_dict["historical"] = data_dict["historical"].fillna(1)
        data_dict["future"] = data_dict["future"].fillna(1)

        try:
            data_dict["historical"].to_csv("historical.csv")
            data_dict["future"].to_csv("future.csv")
        except Exception as e:
            print("Exception: ", e)

        return data_dict