# wrappers.py
import gym
import numpy as np

# Adapted from:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
# https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html


class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(
            low=0, high=1, shape=env.observation_space.shape, dtype=np.float32
        )

    def observation(self, observation):
        # Scale uint8 pixels to float32 in [0, 1]. This undoes the uint8
        # memory optimization, so use it with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0


class FireEpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.

        Done by DeepMind for DQN and co., since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Check the current life count, treat loss of a life as terminal,
        # then update self.lives so that bonus lives are handled correctly.
        lives = self.env.unwrapped.ale.lives()
        if self.lives > lives > 0:
            # Qbert sometimes stays in a lives == 0 condition for a few
            # frames, so it is important to require lives > 0 here: we only
            # reset once the environment actually advertises done.
            # done = True
            # Instead of terminating, press FIRE (action 1) to resume play
            # after the life loss, then action 2 to leave any idle state.
            obs, _, done, _ = self.env.step(1)
            if done:
                self.env.reset()
            obs, _, done, _ = self.env.step(2)
            if done:
                self.env.reset()
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        # Some games need FIRE (action 1) to start; action 2 nudges the game
        # out of any remaining idle state.
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        # Sync the life counter so a stale value from a previous episode
        # cannot trigger a spurious FIRE on the first step.
        self.lives = self.env.unwrapped.ale.lives()
        return obs
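

# --- Usage sketch (not part of the original file) ---
# A minimal example of composing the two wrappers, assuming the old gym
# 4-tuple step API that these classes target. The game id
# "BreakoutNoFrameskip-v4" is an assumption: any ALE game whose action 1 is
# FIRE should work, and gym[atari] must be installed for it to run.
if __name__ == "__main__":
    env = gym.make("BreakoutNoFrameskip-v4")
    env = FireEpisodicLifeEnv(env)  # FIRE on reset / life loss
    env = ScaledFloatFrame(env)     # uint8 frames -> float32 in [0, 1]

    obs = env.reset()
    assert obs.dtype == np.float32 and obs.max() <= 1.0
    for _ in range(100):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
    env.close()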