forked from jdmartin86/frogseye
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsave_data.py
23 lines (20 loc) · 880 Bytes
/
save_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
'''
Implementation of all functions that save data to disk.
'''
import numpy as np
import scripts.plot_utils as plot_utils
def save_mse_data(rewards,
                  preds,
                  filename_prefix,
                  bin_size=100000,
                  discount=0.99,
                  num_steps=5000000,
                  environ='FrogsEye'):
    """Compute squared errors of value predictions against returns and save them.

    Returns are computed with
    ``plot_utils.compute_returns_FrogsEye_experiments`` and compared with
    ``preds`` after dropping its last ``bin_size`` entries. (Presumably this is
    because those steps have no matching return; confirm against plot_utils.)

    Four compressed NumPy archives are written. Each holds one unnamed array,
    stored by NumPy under the key ``arr_0``. NumPy appends ``.npz`` to the name:

    * ``<filename_prefix>_values``    -- ``preds``
    * ``<filename_prefix>_returns``   -- the computed returns
    * ``<filename_prefix>_rewards``   -- ``rewards``
    * ``<filename_prefix>_sq_errors`` -- ``(returns - preds[:len(preds) - bin_size]) ** 2``

    Args:
        rewards: Per-step rewards. They are forwarded to the return computation
            and saved unchanged.
        preds: Per-step value predictions. Presumably the same length as
            ``rewards``; TODO confirm.
        filename_prefix: Prefix (str or path-like) for the output file names.
        bin_size: Number of trailing predictions excluded from the comparison.
            Also forwarded to the return computation.
        discount: Discount factor forwarded to the return computation.
        num_steps: Forwarded to the return computation.
        environ: Environment name. Only ``'FrogsEye'`` is supported.

    Raises:
        ValueError: If ``environ`` is not a supported environment.
    """
    if environ != 'FrogsEye':
        # Without this guard an unsupported environment falls through with
        # `returns` unassigned and crashes with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported environ {environ!r}; only 'FrogsEye' is supported")

    returns = plot_utils.compute_returns_FrogsEye_experiments(
        rewards, discount, bin_size=bin_size, num_steps=num_steps)

    # `preds[:-bin_size]` would be `preds[:0]` (empty) when bin_size == 0, so
    # slice by an explicit, non-negative end index instead. For any positive
    # bin_size this selects exactly the same elements.
    n_compared = max(len(preds) - bin_size, 0)
    sq_errors = (returns - preds[:n_compared]) ** 2

    np.savez_compressed(f"{filename_prefix}_values", preds)
    np.savez_compressed(f"{filename_prefix}_returns", returns)
    np.savez_compressed(f"{filename_prefix}_rewards", rewards)
    np.savez_compressed(f"{filename_prefix}_sq_errors", sq_errors)