Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dgriff777 authored Nov 1, 2018
1 parent 5bf15eb commit 474f480
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,30 +159,31 @@ def reset(self, **kwargs):


class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros(
(2, ) + env.observation_space.shape, dtype=np.uint8)
self._obs_buffer = deque(maxlen=3)
self._skip = skip

def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)

max_frame = np.max(np.stack(self._obs_buffer), axis=0)

return max_frame, total_reward, done, info

def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs

0 comments on commit 474f480

Please sign in to comment.