-
Notifications
You must be signed in to change notification settings - Fork 475
/
stocks_env.py
89 lines (67 loc) · 3.18 KB
/
stocks_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
from .trading_env import TradingEnv, Actions, Positions
class StocksEnv(TradingEnv):
    """Trading environment driven by the ``'Close'`` column of a DataFrame.

    Supplies the data extraction, step reward, running-profit update and
    best-case profit computation on top of ``TradingEnv``.

    NOTE(review): the attributes ``df``, ``window_size``, ``prices``,
    ``_position``, ``_current_tick``, ``_last_trade_tick``, ``_total_profit``,
    ``_truncated``, ``_start_tick`` and ``_end_tick`` are read here but not
    defined in this class; they are presumably maintained by ``TradingEnv`` --
    confirm against the base class.
    """

    def __init__(self, df, window_size, frame_bound, render_mode=None):
        """Store the frame bounds, initialise the base env and set the fees.

        Args:
            df: Data source; ``_process_data`` reads its ``'Close'`` column
                via ``self.df``.
            window_size: Forwarded to ``TradingEnv``; ``_process_data`` uses
                ``self.window_size`` as the number of rows kept before
                ``frame_bound[0]``.
            frame_bound: Two-element sequence ``(start, end)`` of row
                positions; ``end`` is used as an exclusive slice bound.
            render_mode: Forwarded unchanged to ``TradingEnv``.

        Raises:
            AssertionError: If ``frame_bound`` does not have exactly two items.
        """
        assert len(frame_bound) == 2, "frame_bound must be a (start, end) pair"

        # Assigned before the base constructor runs.
        # NOTE(review): presumably TradingEnv.__init__ calls _process_data(),
        # which needs frame_bound -- confirm before reordering.
        self.frame_bound = frame_bound
        super().__init__(df, window_size, render_mode)

        # Fees are fractions, not percentages: _update_profit multiplies by
        # (1 - fee), so 0.01 removes 1% and 0.005 removes 0.5%.
        self.trade_fee_bid_percent = 0.01
        self.trade_fee_ask_percent = 0.005

    def _process_data(self):
        """Slice the closing prices to the frame and build signal features.

        Returns:
            A ``(prices, signal_features)`` tuple of ``float32`` arrays.
            ``prices`` is the ``'Close'`` column restricted to
            ``[frame_bound[0] - window_size, frame_bound[1])``;
            ``signal_features`` stacks, column-wise, those prices and their
            first difference (with a leading 0 so both columns have the same
            length).

        Raises:
            IndexError: If ``frame_bound[0] - window_size`` is not a valid
                non-negative position in the price series.
        """
        prices = self.df.loc[:, 'Close'].to_numpy()

        start = self.frame_bound[0] - self.window_size
        end = self.frame_bound[1]

        # An explicit range check is required: a plain ``prices[start]`` probe
        # accepts negative positions (they wrap around to the end of the
        # array) and the slice below would then silently select the wrong
        # window.
        if not 0 <= start < len(prices):
            raise IndexError(
                f"frame_bound[0] - window_size = {start} is outside the "
                f"available price range [0, {len(prices)})"
            )

        prices = prices[start:end]

        diff = np.insert(np.diff(prices), 0, 0)
        signal_features = np.column_stack((prices, diff))

        return prices.astype(np.float32), signal_features.astype(np.float32)

    def _is_trade(self, action):
        """Return True for a Buy while Short or a Sell while Long."""
        return bool(
            (action == Actions.Buy.value and self._position == Positions.Short) or
            (action == Actions.Sell.value and self._position == Positions.Long)
        )

    def _calculate_reward(self, action):
        """Return the reward for taking ``action`` at the current tick.

        The reward is the price change since the last trade tick when the
        action is a trade (see ``_is_trade``) and the position held is Long;
        otherwise it is 0.
        """
        if not self._is_trade(action):
            return 0

        if self._position == Positions.Long:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]
            return current_price - last_trade_price

        return 0

    def _update_profit(self, action):
        """Update ``self._total_profit`` when a Long position is closed out.

        Runs when ``action`` is a trade (see ``_is_trade``) or when
        ``self._truncated`` is truthy, and only changes the profit if the
        position held is Long.
        """
        if not (self._is_trade(action) or self._truncated):
            return

        if self._position == Positions.Long:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            # The ask fee reduces the capital converted to shares at the last
            # trade price; the bid fee reduces the shares converted back at
            # the current price.
            shares = (self._total_profit * (1 - self.trade_fee_ask_percent)) / last_trade_price
            self._total_profit = (shares * (1 - self.trade_fee_bid_percent)) * current_price

    def max_possible_profit(self):
        """Return the profit multiplier of an ideal, fee-free strategy.

        Walks the prices from ``self._start_tick`` to ``self._end_tick``,
        splitting them into maximal runs of falling (strictly lower than the
        previous tick) and non-falling prices. Starting from 1.0, the profit
        is multiplied by the price ratio across every non-falling run;
        falling runs leave it unchanged. No trade fees are applied.
        """
        prices = self.prices
        end_tick = self._end_tick

        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= end_tick:
            if prices[current_tick] < prices[current_tick - 1]:
                # Falling run: skip to its end without changing the profit.
                while (current_tick <= end_tick and
                       prices[current_tick] < prices[current_tick - 1]):
                    current_tick += 1
            else:
                # Non-falling run: hold from the previous turning point to
                # the last tick of the run.
                while (current_tick <= end_tick and
                       prices[current_tick] >= prices[current_tick - 1]):
                    current_tick += 1

                current_price = prices[current_tick - 1]
                last_trade_price = prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price

            # The final tick of the run just consumed is the next turning
            # point.
            last_trade_tick = current_tick - 1

        return profit