import queue
import time
from multiprocessing import Queue, Process

import cv2  # pytype:disable=import-error
import numpy as np
from joblib import Parallel, delayed

from stable_baselines import logger


class ExpertDataset(object):
"""
    Dataset for behavior cloning or GAIL.
The structure of the expert dataset is a dict, saved as an ".npz" archive.
The dictionary contains the keys 'actions', 'episode_returns', 'rewards', 'obs' and 'episode_starts'.
    The corresponding values have data concatenated across episodes: the first axis is the timestep,
    the remaining axes index into the data. In the case of images, 'obs' contains the relative paths to
    the images, to save space via image compression.
:param expert_path: (str) The path to trajectory data (.npz file). Mutually exclusive with traj_data.
:param traj_data: (dict) Trajectory data, in format described above. Mutually exclusive with expert_path.
    :param train_fraction: (float) the train/validation split fraction (0 to 1)
        for pre-training using behavior cloning (BC)
:param batch_size: (int) the minibatch size for behavior cloning
    :param traj_limitation: (int) the number of trajectories to use (if -1, load all)
:param randomize: (bool) if the dataset should be shuffled
:param verbose: (int) Verbosity
    :param sequential_preprocessing: (bool) Do not use a subprocess to preprocess
        the data (slower, but uses less memory, e.g. for the CI)
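
    A minimal usage sketch (the archive name 'expert_cartpole.npz' below is a
    hypothetical placeholder; any '.npz' file with the structure described above works)::

        dataset = ExpertDataset(expert_path='expert_cartpole.npz',
                                traj_limitation=10, batch_size=128)
        obs, actions = dataset.get_next_batch('train')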
"""
    # Attributes excluded when pickling the object
EXCLUDED_KEYS = {'dataloader', 'train_loader', 'val_loader'}
def __init__(self, expert_path=None, traj_data=None, train_fraction=0.7, batch_size=64,
traj_limitation=-1, randomize=True, verbose=1, sequential_preprocessing=False):
if traj_data is not None and expert_path is not None:
raise ValueError("Cannot specify both 'traj_data' and 'expert_path'")
if traj_data is None and expert_path is None:
raise ValueError("Must specify one of 'traj_data' or 'expert_path'")
if traj_data is None:
traj_data = np.load(expert_path, allow_pickle=True)
if verbose > 0:
for key, val in traj_data.items():
print(key, val.shape)
# Array of bool where episode_starts[i] = True for each new episode
episode_starts = traj_data['episode_starts']
traj_limit_idx = len(traj_data['obs'])
if traj_limitation > 0:
n_episodes = 0
# Retrieve the index corresponding
# to the traj_limitation trajectory
for idx, episode_start in enumerate(episode_starts):
n_episodes += int(episode_start)
if n_episodes == (traj_limitation + 1):
traj_limit_idx = idx - 1
observations = traj_data['obs'][:traj_limit_idx]
actions = traj_data['actions'][:traj_limit_idx]
# obs, actions: shape (N * L, ) + S
# where N = # episodes, L = episode length
# and S is the environment observation/action space.
# S = (1, ) for discrete space
# Flatten to (N * L, prod(S))
if len(observations.shape) > 2:
observations = np.reshape(observations, [-1, np.prod(observations.shape[1:])])
if len(actions.shape) > 2:
actions = np.reshape(actions, [-1, np.prod(actions.shape[1:])])
indices = np.random.permutation(len(observations)).astype(np.int64)
# Train/Validation split when using behavior cloning
train_indices = indices[:int(train_fraction * len(indices))]
val_indices = indices[int(train_fraction * len(indices)):]
assert len(train_indices) > 0, "No sample for the training set"
assert len(val_indices) > 0, "No sample for the validation set"
self.observations = observations
self.actions = actions
self.returns = traj_data['episode_returns'][:traj_limit_idx]
self.avg_ret = sum(self.returns) / len(self.returns)
self.std_ret = np.std(np.array(self.returns))
self.verbose = verbose
        assert len(self.observations) == len(self.actions), "The number of actions and observations differ, " \
                                                            "please check your expert dataset"
        self.num_traj = np.sum(episode_starts) if traj_limitation < 0 else min(traj_limitation, np.sum(episode_starts))
self.num_transition = len(self.observations)
self.randomize = randomize
self.sequential_preprocessing = sequential_preprocessing
self.dataloader = None
self.train_loader = DataLoader(train_indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=sequential_preprocessing)
self.val_loader = DataLoader(val_indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=sequential_preprocessing)
if self.verbose >= 1:
self.log_info()
def init_dataloader(self, batch_size):
"""
Initialize the dataloader used by GAIL.
        :param batch_size: (int) the minibatch size
"""
indices = np.random.permutation(len(self.observations)).astype(np.int64)
self.dataloader = DataLoader(indices, self.observations, self.actions, batch_size,
shuffle=self.randomize, start_process=False,
sequential=self.sequential_preprocessing)
def __del__(self):
# Exit processes if needed
for key in self.EXCLUDED_KEYS:
if self.__dict__.get(key) is not None:
del self.__dict__[key]
def __getstate__(self):
"""
Gets state for pickling.
Excludes processes that are not pickleable
"""
# Remove processes in order to pickle the dataset.
return {key: val for key, val in self.__dict__.items() if key not in self.EXCLUDED_KEYS}
def __setstate__(self, state):
"""
Restores pickled state.
init_dataloader() must be called
after unpickling before using it with GAIL.
:param state: (dict)
"""
self.__dict__.update(state)
for excluded_key in self.EXCLUDED_KEYS:
assert excluded_key not in state
self.dataloader = None
self.train_loader = None
self.val_loader = None
def log_info(self):
"""
Log the information of the dataset.
"""
logger.log("Total trajectories: {}".format(self.num_traj))
logger.log("Total transitions: {}".format(self.num_transition))
logger.log("Average returns: {}".format(self.avg_ret))
logger.log("Std for returns: {}".format(self.std_ret))
def get_next_batch(self, split=None):
"""
Get the batch from the dataset.
:param split: (str) the type of data split (can be None, 'train', 'val')
:return: (np.ndarray, np.ndarray) inputs and labels
"""
dataloader = {
None: self.dataloader,
'train': self.train_loader,
'val': self.val_loader
}[split]
if dataloader.process is None:
dataloader.start_process()
try:
return next(dataloader)
except StopIteration:
dataloader = iter(dataloader)
return next(dataloader)
def plot(self):
"""
        Show a histogram of the episode returns
"""
# Isolate dependency since it is only used for plotting and also since
# different matplotlib backends have further dependencies themselves.
import matplotlib.pyplot as plt
plt.hist(self.returns)
plt.show()


class DataLoader(object):
"""
    A custom dataloader to preprocess observations (including images)
    and feed them to the network.
    Original code for the dataloader from https://github.com/araffin/robotics-rl-srl
(MIT licence)
Authors: Antonin Raffin, René Traoré, Ashley Hill
    :param indices: ([int]) list of observation indices
:param observations: (np.ndarray) observations or images path
:param actions: (np.ndarray) actions
:param batch_size: (int) Number of samples per minibatch
    :param n_workers: (int) number of preprocessing workers (for loading the images)
    :param infinite_loop: (bool) whether to loop over the dataset indefinitely (so the iterator can be reset)
    :param max_queue_len: (int) Max number of minibatches that can be preprocessed at the same time
    :param shuffle: (bool) Shuffle the indices at the start of each epoch
    :param start_process: (bool) Start the preprocessing process (default: True)
    :param backend: (str) joblib backend (one of 'multiprocessing', 'sequential', 'threading'
        or 'loky' in newest versions)
    :param sequential: (bool) Do not use a subprocess to preprocess the data
        (slower, but uses less memory, e.g. for the CI)
    :param partial_minibatch: (bool) Allow partial minibatches (minibatches with fewer
        elements than batch_size)
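
    A minimal sketch of direct, sequential use (the arrays below are synthetic
    placeholders, not real expert data)::

        indices = np.arange(100)
        observations = np.random.randn(100, 4).astype(np.float32)
        actions = np.random.randn(100, 1).astype(np.float32)
        loader = DataLoader(indices, observations, actions, batch_size=32,
                            shuffle=True, sequential=True, start_process=False)
        for obs_batch, action_batch in loader:
            pass  # feed the minibatch to the network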
"""
def __init__(self, indices, observations, actions, batch_size, n_workers=1,
infinite_loop=True, max_queue_len=1, shuffle=False,
start_process=True, backend='threading', sequential=False, partial_minibatch=True):
super(DataLoader, self).__init__()
self.n_workers = n_workers
self.infinite_loop = infinite_loop
self.indices = indices
self.original_indices = indices.copy()
self.n_minibatches = len(indices) // batch_size
        # Add a partial minibatch, for instance
        # when there are not enough samples
if partial_minibatch and len(indices) % batch_size > 0:
self.n_minibatches += 1
self.batch_size = batch_size
self.observations = observations
self.actions = actions
self.shuffle = shuffle
self.queue = Queue(max_queue_len)
self.process = None
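        # If observations are given as file paths, the images are loaded from disk when building minibatches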
self.load_images = isinstance(observations[0], str)
self.backend = backend
self.sequential = sequential
self.start_idx = 0
if start_process:
self.start_process()
def start_process(self):
"""Start preprocessing process"""
# Skip if in sequential mode
if self.sequential:
return
self.process = Process(target=self._run)
        # Make it a daemon, so it will be terminated
        # along with the main process
self.process.daemon = True
self.process.start()
@property
def _minibatch_indices(self):
"""
Current minibatch indices given the current pointer
(start_idx) and the minibatch size
:return: (np.ndarray) 1D array of indices
"""
return self.indices[self.start_idx:self.start_idx + self.batch_size]
def sequential_next(self):
"""
Sequential version of the pre-processing.
"""
        # Stop once all indices have been consumed
        if self.start_idx >= len(self.indices):
            raise StopIteration
if self.start_idx == 0:
if self.shuffle:
# Shuffle indices
np.random.shuffle(self.indices)
obs = self.observations[self._minibatch_indices]
if self.load_images:
obs = np.concatenate([self._make_batch_element(image_path) for image_path in obs],
axis=0)
actions = self.actions[self._minibatch_indices]
self.start_idx += self.batch_size
return obs, actions
def _run(self):
start = True
with Parallel(n_jobs=self.n_workers, batch_size="auto", backend=self.backend) as parallel:
while start or self.infinite_loop:
start = False
if self.shuffle:
np.random.shuffle(self.indices)
for minibatch_idx in range(self.n_minibatches):
self.start_idx = minibatch_idx * self.batch_size
obs = self.observations[self._minibatch_indices]
if self.load_images:
if self.n_workers <= 1:
obs = [self._make_batch_element(image_path)
for image_path in obs]
else:
obs = parallel(delayed(self._make_batch_element)(image_path)
for image_path in obs)
obs = np.concatenate(obs, axis=0)
actions = self.actions[self._minibatch_indices]
self.queue.put((obs, actions))
# Free memory
del obs
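                # End of epoch: send a sentinel so that __next__ can raise StopIteration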
self.queue.put(None)
@classmethod
def _make_batch_element(cls, image_path):
"""
Process one element.
:param image_path: (str) path to an image
:return: (np.ndarray)
"""
        # cv2.IMREAD_UNCHANGED is needed to load
        # greyscale and RGBA images
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError("Tried to load {}, but it was not found".format(image_path))
        # Grey image: add a channel dimension
        if len(image.shape) == 2:
            image = image[:, :, np.newaxis]
# Convert from BGR to RGB
if image.shape[-1] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.reshape((1,) + image.shape)
return image
def __len__(self):
return self.n_minibatches
def __iter__(self):
self.start_idx = 0
self.indices = self.original_indices.copy()
return self
def __next__(self):
if self.sequential:
return self.sequential_next()
if self.process is None:
raise ValueError("You must call .start_process() before using the dataloader")
while True:
try:
val = self.queue.get_nowait()
break
except queue.Empty:
time.sleep(0.001)
continue
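        # None is the end-of-epoch sentinel sent by the worker in _run()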
if val is None:
raise StopIteration
return val
def __del__(self):
if self.process is not None:
self.process.terminate()
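

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the original module): build a tiny
    # synthetic expert dataset in memory and fetch a few minibatches from it.
    # All shapes and values below are arbitrary placeholders, not real expert data.
    n_steps, obs_dim, ep_len = 200, 4, 50
    demo_traj_data = {
        'obs': np.random.randn(n_steps, obs_dim).astype(np.float32),
        'actions': np.random.randn(n_steps, 1).astype(np.float32),
        'rewards': np.random.randn(n_steps).astype(np.float32),
        'episode_starts': np.array([step % ep_len == 0 for step in range(n_steps)]),
        'episode_returns': np.random.randn(n_steps // ep_len).astype(np.float32),
    }
    demo_dataset = ExpertDataset(traj_data=demo_traj_data, train_fraction=0.7,
                                 batch_size=32, sequential_preprocessing=True)
    # Behavior cloning style: sample from the train/validation splits
    obs_batch, action_batch = demo_dataset.get_next_batch('train')
    print("train batch:", obs_batch.shape, action_batch.shape)
    # GAIL style: create the full dataloader first, then sample from it
    demo_dataset.init_dataloader(batch_size=32)
    obs_batch, action_batch = demo_dataset.get_next_batch()
    print("full batch:", obs_batch.shape, action_batch.shape)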