import numpy as np


class ReplayMemory():
    """Fixed-size circular buffer of transitions with uniform random sampling.

    Transitions are stored column-wise in pre-allocated NumPy arrays. Once
    ``max_mem_len`` transitions have been written, new ones overwrite the
    oldest slots.
    """

    def __init__(self, max_mem_len, input_shape, n_actions):
        """Allocate storage for ``max_mem_len`` transitions.

        Args:
            max_mem_len: Capacity of the buffer (number of transitions).
            input_shape: Shape of one state; unpacked after the leading
                capacity axis of the state arrays.
            n_actions: Accepted for interface compatibility; not used here.
        """
        self.mem_len = max_mem_len
        self.states = np.zeros((self.mem_len, *input_shape), dtype=np.float32)
        self.actions = np.zeros(self.mem_len, dtype=np.int64)
        self.rewards = np.zeros(self.mem_len, dtype=np.float32)
        # Holds the `_state` argument of each transition (presumably the
        # state observed after the action -- confirm against the caller).
        self._states = np.zeros((self.mem_len, *input_shape), dtype=np.float32)
        # `np.bool` was a deprecated alias that NumPy 1.24 removed; `np.bool_`
        # is the boolean scalar type available in every NumPy release.
        self.dones = np.zeros(self.mem_len, dtype=np.bool_)
        # Total number of transitions ever stored; never wraps, the write
        # slot is derived from it modulo the capacity.
        self.next_available_memory = 0

    def store_transition(self, state, action, reward, _state, done):
        """Write one transition, overwriting the oldest slot once full.

        Args:
            state: Value stored in ``states`` for this slot.
            action: Value stored in ``actions`` (cast to int64).
            reward: Value stored in ``rewards`` (cast to float32).
            _state: Value stored in ``_states`` for this slot.
            done: Value stored in ``dones`` (cast to bool).
        """
        memory = self.next_available_memory % self.mem_len

        self.states[memory] = state
        self.actions[memory] = action
        self.rewards[memory] = reward
        self._states[memory] = _state
        self.dones[memory] = done

        self.next_available_memory += 1

    def minibatch(self, minibatch_size=32):
        """Sample distinct stored transitions uniformly at random.

        Args:
            minibatch_size: Number of transitions to draw.

        Returns:
            Tuple ``(states, actions, rewards, _states, dones)`` of arrays,
            each indexed by the same sampled slots.

        Raises:
            ValueError: Raised by ``np.random.choice`` when fewer than
                ``minibatch_size`` transitions are stored (sampling is
                without replacement).
        """
        # Only slots that have been written are eligible for sampling.
        most_mems = min(self.next_available_memory, self.mem_len)
        memories = np.random.choice(most_mems, minibatch_size, replace=False)

        return (
            self.states[memories],
            self.actions[memories],
            self.rewards[memories],
            self._states[memories],
            self.dones[memories],
        )