forked from openai/baselines
-
Notifications
You must be signed in to change notification settings - Fork 725
/
record_expert.py
183 lines (148 loc) · 6.7 KB
/
record_expert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import warnings
from typing import Dict
import cv2 # pytype:disable=import-error
import numpy as np
from gym import spaces
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.vec_env import VecEnv, VecFrameStack
from stable_baselines.common.base_class import _UnvecWrapper
def generate_expert_traj(model, save_path=None, env=None, n_timesteps=0,
                         n_episodes=100, image_folder='recorded_images'):
    """
    Train expert controller (if needed) and record expert trajectories.

    .. note::

        only Box and Discrete spaces are supported for now.

    :param model: (RL model or callable) The expert model, if it needs to be trained,
        then you need to pass ``n_timesteps > 0``.
    :param save_path: (str) Path without the extension where the expert dataset will be saved
        (ex: 'expert_cartpole' -> creates 'expert_cartpole.npz').
        If not specified, it will not save, and just return the generated expert trajectories.
        This parameter must be specified for image-based environments.
    :param env: (gym.Env) The environment, if not defined then it tries to use the model
        environment.
    :param n_timesteps: (int) Number of training timesteps
    :param n_episodes: (int) Number of trajectories (episodes) to record
    :param image_folder: (str) When using images, folder that will be used to record images.
    :return: (dict) the generated expert trajectories.
    """
    # Retrieve the environment using the RL model
    if env is None and isinstance(model, BaseRLModel):
        env = model.get_env()

    assert env is not None, "You must set the env in the model or pass it to the function."

    is_vec_env = False
    if isinstance(env, VecEnv) and not isinstance(env, _UnvecWrapper):
        is_vec_env = True
        if env.num_envs > 1:
            warnings.warn("You are using multiple envs, only the data from the first one will be recorded.")

    # Sanity check: only Box and Discrete spaces are handled below.
    assert (isinstance(env.observation_space, spaces.Box) or
            isinstance(env.observation_space, spaces.Discrete)), "Observation space type not supported"

    assert (isinstance(env.action_space, spaces.Box) or
            isinstance(env.action_space, spaces.Discrete)), "Action space type not supported"

    # Check if we need to record images: 3D uint8 observations with 1/3/4
    # channels are treated as images and saved to disk instead of kept in RAM.
    obs_space = env.observation_space
    record_images = len(obs_space.shape) == 3 and obs_space.shape[-1] in [1, 3, 4] \
                    and obs_space.dtype == np.uint8
    if record_images and save_path is None:
        warnings.warn("Observations are images but no save path was specified, so will save in numpy archive; "
                      "this can lead to higher memory usage.")
        record_images = False

    if not record_images and len(obs_space.shape) == 3 and obs_space.dtype == np.uint8:
        warnings.warn("The observations looks like images (shape = {}) "
                      "but the number of channel > 4, so it will be saved in the numpy archive "
                      "which can lead to high memory usage".format(obs_space.shape))

    image_ext = 'jpg'
    if record_images:
        # We save images as jpg or png, that have only 3/4 color channels
        if isinstance(env, VecFrameStack) and env.n_stack == 4:
            # assert env.n_stack < 5, "The current data recorder does no support"\
            #                         "VecFrameStack with n_stack > 4"
            image_ext = 'png'

        folder_path = os.path.dirname(save_path)
        image_folder = os.path.join(folder_path, image_folder)
        os.makedirs(image_folder, exist_ok=True)
        print("=" * 10)
        print("Images will be recorded to {}/".format(image_folder))
        print("Image shape: {}".format(obs_space.shape))
        print("=" * 10)

    # Train the expert first if requested (only RL models can learn).
    if n_timesteps > 0 and isinstance(model, BaseRLModel):
        model.learn(n_timesteps)

    actions = []
    observations = []
    rewards = []
    episode_returns = np.zeros((n_episodes,))
    episode_starts = []

    ep_idx = 0
    obs = env.reset()
    episode_starts.append(True)
    reward_sum = 0.0
    idx = 0
    # state and mask for recurrent policies
    state, mask = None, None

    if is_vec_env:
        mask = [True for _ in range(env.num_envs)]

    while ep_idx < n_episodes:
        # Only the first env's data is recorded when using a VecEnv.
        obs_ = obs[0] if is_vec_env else obs
        if record_images:
            image_path = os.path.join(image_folder, "{}.{}".format(idx, image_ext))
            # Convert from RGB to BGR
            # which is the format OpenCV expect
            if obs_.shape[-1] == 3:
                obs_ = cv2.cvtColor(obs_, cv2.COLOR_RGB2BGR)
            cv2.imwrite(image_path, obs_)
            # Store the path instead of the raw pixels to save memory.
            observations.append(image_path)
        else:
            observations.append(obs_)

        if isinstance(model, BaseRLModel):
            action, state = model.predict(obs, state=state, mask=mask)
        else:
            # The expert may be a plain callable: obs -> action.
            action = model(obs)

        obs, reward, done, _ = env.step(action)

        # Use only first env
        if is_vec_env:
            mask = [done[0] for _ in range(env.num_envs)]
            action = np.array([action[0]])
            reward = np.array([reward[0]])
            done = np.array([done[0]])

        actions.append(action)
        rewards.append(reward)
        episode_starts.append(done)
        reward_sum += reward
        idx += 1
        if done:
            if not is_vec_env:
                # A VecEnv resets automatically; a raw env must be reset by hand.
                obs = env.reset()
                # Reset the state in case of a recurrent policy
                state = None
            episode_returns[ep_idx] = reward_sum
            reward_sum = 0.0
            ep_idx += 1

    # Reshape the recorded data into flat (n_samples, ...) arrays.
    if isinstance(env.observation_space, spaces.Box) and not record_images:
        observations = np.concatenate(observations).reshape((-1,) + env.observation_space.shape)
    elif isinstance(env.observation_space, spaces.Discrete):
        observations = np.array(observations).reshape((-1, 1))
    elif record_images:
        # When recording images, observations are the saved file paths.
        observations = np.array(observations)

    if isinstance(env.action_space, spaces.Box):
        actions = np.concatenate(actions).reshape((-1,) + env.action_space.shape)
    elif isinstance(env.action_space, spaces.Discrete):
        actions = np.array(actions).reshape((-1, 1))

    rewards = np.array(rewards)
    episode_starts = np.array(episode_starts[:-1])

    assert len(observations) == len(actions)

    numpy_dict = {
        'actions': actions,
        'obs': observations,
        'rewards': rewards,
        'episode_returns': episode_returns,
        'episode_starts': episode_starts
    }  # type: Dict[str, np.ndarray]

    for key, val in numpy_dict.items():
        print(key, val.shape)

    if save_path is not None:
        np.savez(save_path, **numpy_dict)

    env.close()

    return numpy_dict