import random

import numpy as np


class ER(object):
    """Experience replay buffer storing transitions (state, action, reward,
    terminal) plus optional MuJoCo simulator state (qpos, qvel)."""

    def __init__(self, memory_size, state_dim, action_dim, reward_dim,
                 qpos_dim, qvel_dim, batch_size, history_length=1):
        self.memory_size = memory_size
        # buffers are initialized with small Gaussian noise rather than zeros
        self.actions = np.random.normal(scale=0.35, size=(self.memory_size, action_dim))
        self.rewards = np.random.normal(scale=0.35, size=(self.memory_size,))
        self.states = np.random.normal(scale=0.35, size=(self.memory_size,) + state_dim)
        self.qpos = np.random.normal(scale=0.35, size=(self.memory_size, qpos_dim))
        self.qvel = np.random.normal(scale=0.35, size=(self.memory_size, qvel_dim))
        self.terminals = np.zeros(self.memory_size, dtype=np.float32)
        self.batch_size = batch_size
        self.history_length = history_length
        self.count = 0    # number of valid entries stored so far
        self.current = 0  # next write position (circular)
        self.state_dim = state_dim
        self.action_dim = action_dim
        # note: reward_dim is accepted for API compatibility but never used

        # pre-allocate prestates and poststates for minibatch
        self.prestates = np.empty((self.batch_size, self.history_length) + state_dim, dtype=np.float32)
        self.poststates = np.empty((self.batch_size, self.history_length) + state_dim, dtype=np.float32)

        self.traj_length = 2
        self.traj_states = np.empty((self.batch_size, self.traj_length) + state_dim, dtype=np.float32)
        self.traj_actions = np.empty((self.batch_size, self.traj_length - 1, action_dim), dtype=np.float32)

    def add(self, actions, rewards, next_states, terminals, qposs=(), qvels=()):
        # state is the post-state, i.e. the observation seen after the action
        # and reward (empty-tuple defaults avoid the mutable-default pitfall)
        for idx in range(len(actions)):
            self.actions[self.current, ...] = actions[idx]
            self.rewards[self.current] = rewards[idx]
            self.states[self.current, ...] = next_states[idx]
            self.terminals[self.current] = terminals[idx]
            if len(qposs) == len(actions):
                self.qpos[self.current, ...] = qposs[idx]
                self.qvel[self.current, ...] = qvels[idx]
            self.count = max(self.count, self.current + 1)
            self.current = (self.current + 1) % self.memory_size

    def get_state(self, index):
        assert self.count > 0, "replay memory is empty"
        # normalize index to the expected range; allows negative indexes
        index = index % self.count
        if index >= self.history_length - 1:
            # the history window does not wrap: use faster slicing
            return self.states[(index - (self.history_length - 1)):(index + 1), ...]
        else:
            # otherwise normalize the indexes and use slower list-based access
            indexes = [(index - i) % self.count for i in reversed(range(self.history_length))]
            return self.states[indexes, ...]

    def sample(self, indexes=None):
        # memory must include poststate, prestate and history
        assert self.count > self.history_length
        if indexes is None:
            # sample random indexes
            indexes = []
            while len(indexes) < self.batch_size:
                # find a usable random index
                while True:
                    # sample one index, ignoring states that wrap over the
                    # buffer boundary
                    index = random.randint(self.history_length, self.count - 1)
                    # if it wraps over the current write pointer, resample
                    if index >= self.current > index - self.history_length:
                        continue
                    # if it wraps over an episode end, resample;
                    # the poststate (last screen) itself may be terminal
                    if self.terminals[(index - self.history_length):index].any():
                        continue
                    # otherwise use this index
                    break
                indexes.append(index)

        # fill the pre/post state buffers for the chosen indexes, including
        # indexes supplied by the caller; having index first is fastest in
        # C-order matrices
        for i, index in enumerate(indexes):
            self.prestates[i, ...] = self.get_state(index - 1)
            self.poststates[i, ...] = self.get_state(index)

        actions = self.actions[indexes, ...]
        rewards = self.rewards[indexes, ...]
        if hasattr(self, 'qpos'):
            qpos = self.qpos[indexes, ...]
            qvels = self.qvel[indexes, ...]
        else:
            qpos = []
            qvels = []
        terminals = self.terminals[indexes]

        # the final squeeze assumes history_length == 1
        return np.squeeze(self.prestates, axis=1), actions, rewards, \
               np.squeeze(self.poststates, axis=1), terminals, qpos, qvels
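

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): the
# dimensions below are hypothetical stand-ins for a MuJoCo-style task; any
# environment with flat observation vectors would be wired up the same way.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    state_dim = (17,)  # observation shape; __init__ expects a tuple

    er = ER(memory_size=1000, state_dim=state_dim, action_dim=6, reward_dim=1,
            qpos_dim=9, qvel_dim=9, batch_size=32, history_length=1)

    # push a short synthetic rollout of 50 transitions
    n = 50
    actions = np.random.randn(n, 6)
    rewards = np.random.randn(n)
    next_states = np.random.randn(n, 17)
    terminals = np.zeros(n)
    terminals[-1] = 1  # mark the end of the episode
    er.add(actions, rewards, next_states, terminals,
           qposs=np.random.randn(n, 9), qvels=np.random.randn(n, 9))

    # draw a random minibatch of (s, a, r, s', done, qpos, qvel)
    states, acts, rews, next_s, terms, qpos, qvel = er.sample()
    print(states.shape, acts.shape, rews.shape)  # (32, 17) (32, 6) (32,)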