-
-
Notifications
You must be signed in to change notification settings - Fork 121
/
Copy pathstock_prediction_plotter.py
80 lines (72 loc) · 3.35 KB
/
stock_prediction_plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright 2020-2024 Jordi Corbilla. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import matplotlib.pyplot as plt
class Plotter:
    """Save (and optionally display) the charts produced for a stock-prediction run.

    Every public ``plot_*`` method writes one or more PNG files into
    ``project_folder`` and then calls ``plt.show(block=self.blocking)``.

    Args:
        blocking: forwarded to ``plt.show(block=...)`` after each plot.
        project_folder: directory the PNG files are written to.
        short_name: display name used in titles and legends; with surrounding
            whitespace and all dots removed it is also the PNG file-name prefix.
        currency: unit shown on price axes as ``Price [<currency>]``.
        stock_ticker: used to select the ``<stock_ticker>_predicted`` column
            in ``project_plot_predictions``.
    """

    def __init__(self, blocking, project_folder, short_name, currency, stock_ticker):
        self.blocking = blocking
        self.project_folder = project_folder
        self.short_name = short_name
        self.currency = currency
        self.stock_ticker = stock_ticker

    def _file_prefix(self):
        """Return ``short_name`` trimmed and with dots removed, for use in file names."""
        return self.short_name.strip().replace('.', '')

    def _output_path(self, file_name):
        """Return the path of ``file_name`` inside the project folder."""
        return os.path.join(self.project_folder, file_name)

    def _show(self):
        """Briefly run the GUI event loop, then show figures honouring ``self.blocking``."""
        plt.pause(0.001)
        plt.show(block=self.blocking)

    def _plot_train_val(self, history, key, ylabel, title, file_name):
        """Plot ``history.history[key]`` and its ``'val_' + key`` twin on a new figure.

        A fresh figure is created so the curves are never drawn on top of a
        figure left open by an earlier call (e.g. the histogram or the loss plot).
        """
        val_key = 'val_' + key
        fig = plt.figure()
        plt.plot(history.history[key], label=key)
        plt.plot(history.history[val_key], label=val_key)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend(loc='upper right')
        fig.savefig(self._output_path(file_name))
        self._show()

    def plot_histogram_data_split(self, training_data, test_data, validation_date):
        """Plot the training/validation price split and a histogram of the training data.

        Writes ``<prefix>_price.png`` (``Close`` of both sets) and
        ``<prefix>_hist.png`` (``training_data.hist``) into the project folder.

        Args:
            training_data: object with a ``Close`` series and a ``hist(ax=...)``
                method (presumably a pandas DataFrame); drawn in green.
            test_data: object with a ``Close`` series; drawn in red and labelled
                as the validation data.
            validation_date: date shown in the legend as the validation lower
                bound; must support ``strftime``.
        """
        print("plotting Data and Histogram")
        prefix = self._file_prefix()

        price_fig = plt.figure(figsize=(12, 5))
        plt.plot(training_data.Close, color='green')
        plt.plot(test_data.Close, color='red')
        plt.ylabel('Price [' + self.currency + ']')
        plt.xlabel("Date")
        plt.legend(["Training Data", "Validation Data >= " + validation_date.strftime("%Y-%m-%d")])
        plt.title(self.short_name)
        price_fig.savefig(self._output_path(prefix + '_price.png'))

        hist_fig, hist_ax = plt.subplots()
        training_data.hist(ax=hist_ax)
        hist_fig.savefig(self._output_path(prefix + '_hist.png'))
        self._show()

    def plot_loss(self, history):
        """Plot per-epoch training vs. validation loss and save it as ``loss.png``.

        Args:
            history: object whose ``history`` mapping holds ``'loss'`` and
                ``'val_loss'`` sequences (presumably a Keras ``History``).
        """
        print("plotting loss")
        self._plot_train_val(history, 'loss', 'Loss', 'Loss/Validation Loss', 'loss.png')

    def plot_mse(self, history):
        """Plot per-epoch training vs. validation MSE and save it as ``MSE.png``.

        Args:
            history: object whose ``history`` mapping holds ``'MSE'`` and
                ``'val_MSE'`` sequences (presumably a Keras ``History``).
        """
        print("plotting MSE")
        self._plot_train_val(history, 'MSE', 'MSE', 'MSE/Validation MSE', 'MSE.png')

    def project_plot_predictions(self, price_predicted, test_data):
        """Plot predicted against actual prices and save ``<prefix>_prediction.png``.

        Args:
            price_predicted: indexable by ``'<stock_ticker>_predicted'``; that
                series is drawn in red.
            test_data: object with a ``Close`` series holding actual prices;
                drawn in green.
        """
        print("plotting predictions")
        fig = plt.figure(figsize=(14, 5))
        plt.plot(price_predicted[self.stock_ticker + '_predicted'], color='red', label='Predicted [' + self.short_name + '] price')
        plt.plot(test_data.Close, color='green', label='Actual [' + self.short_name + '] price')
        plt.xlabel('Time')
        plt.ylabel('Price [' + self.currency + ']')
        plt.legend()
        plt.title('Prediction')
        fig.savefig(self._output_path(self._file_prefix() + '_prediction.png'))
        self._show()