atari.py

from copy import deepcopy
from collections import deque

import gym
import numpy as np

from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.utils.spaces import *
from mushroom_rl.utils.frames import LazyFrames, preprocess_frame


class MaxAndSkip(gym.Wrapper):
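    """
    Gym wrapper that repeats each action for `skip` frames and pools the
    last two raw frames into a single observation, as in Mnih et al. (2015).

    """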
def __init__(self, env, skip, max_pooling=True):
gym.Wrapper.__init__(self, env)
self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
dtype=np.uint8)
self._skip = skip
self._max_pooling = max_pooling

    def step(self, action):
total_reward = 0.
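        # Repeat the action for up to `skip` frames, accumulating the reward
        # and buffering the last two raw frames for pooling.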
for i in range(self._skip):
obs, reward, absorbing, info = self.env.step(action)
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += reward
if absorbing:
break
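
        # Pool the two buffered frames into one observation; max-pooling
        # removes the sprite flicker caused by the Atari hardware rendering
        # some objects only on alternate frames.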
if self._max_pooling:
frame = self._obs_buffer.max(axis=0)
else:
frame = self._obs_buffer.mean(axis=0)
return frame, total_reward, absorbing, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class Atari(Environment):
    """
    The Atari environment as presented in:
    "Human-level control through deep reinforcement learning", Mnih et al.,
    2015.

    """
    def __init__(self, name, width=84, height=84, ends_at_life=False,
                 max_pooling=True, history_length=4, max_no_op_actions=30):
        """
        Constructor.

        Args:
            name (str): id name of the Atari game in Gym;
            width (int, 84): width of the screen;
            height (int, 84): height of the screen;
            ends_at_life (bool, False): whether the episode ends when a life
                is lost or not;
            max_pooling (bool, True): whether to do max-pooling or
                average-pooling of the last two frames when using NoFrameskip;
            history_length (int, 4): number of frames to form a state;
            max_no_op_actions (int, 30): maximum number of no-op actions to
                execute at the beginning of an episode.

        """
        # MDP creation
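        # NoFrameskip variants do not skip frames in Gym itself, so frame
        # skipping and pooling are delegated to the MaxAndSkip wrapper
        # defined above.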
if 'NoFrameskip' in name:
self.env = MaxAndSkip(gym.make(name), history_length, max_pooling)
else:
self.env = gym.make(name)
# MDP parameters
self._img_size = (width, height)
self._episode_ends_at_life = ends_at_life
self._max_lives = self.env.unwrapped.ale.lives()
self._lives = self._max_lives
self._force_fire = None
self._real_reset = True
self._max_no_op_actions = max_no_op_actions
self._history_length = history_length
self._current_no_op = None
assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'
# MDP properties
action_space = Discrete(self.env.action_space.n)
observation_space = Box(
low=0., high=255., shape=(history_length, self._img_size[1], self._img_size[0]))
horizon = np.inf # the gym time limit is used.
gamma = .99
mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)
super().__init__(mdp_info)

    def reset(self, state=None):
if self._real_reset:
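            # A full reset is only performed when the underlying Gym episode
            # actually ended; after a life loss with ends_at_life=True, the
            # current frame history is kept.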
self._state = preprocess_frame(self.env.reset(), self._img_size)
            self._state = deque(
                [deepcopy(self._state) for _ in range(self._history_length)],
                maxlen=self._history_length)
self._lives = self._max_lives
self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
self._current_no_op = np.random.randint(self._max_no_op_actions + 1)
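
        # LazyFrames keeps references to the stacked frames instead of
        # copying them, which saves memory when states are stored in a
        # replay buffer.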
return LazyFrames(list(self._state), self._history_length)

    def step(self, action):
action = action[0]
# Force FIRE action to start episodes in games with lives
if self._force_fire:
obs, _, _, _ = self.env.env.step(1)
self._force_fire = False
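
        # Execute the sampled number of no-ops on the raw environment to
        # randomize the initial state, as in Mnih et al. (2015).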
while self._current_no_op > 0:
obs, _, _, _ = self.env.env.step(0)
self._current_no_op -= 1
obs, reward, absorbing, info = self.env.step(action)
self._real_reset = absorbing
if info['lives'] != self._lives:
if self._episode_ends_at_life:
absorbing = True
self._lives = info['lives']
            self._force_fire = \
                self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
self._state.append(preprocess_frame(obs, self._img_size))
return LazyFrames(list(self._state),
self._history_length), reward, absorbing, info

    def render(self, mode='human'):
self.env.render(mode=mode)

    def stop(self):
self.env.close()
self._real_reset = True

    def set_episode_end(self, ends_at_life):
        """
        Setter.

        Args:
            ends_at_life (bool): whether the episode ends when a life is
                lost or not.

        """
self._episode_ends_at_life = ends_at_life
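

# Minimal usage sketch (not part of the original file). It assumes the Atari
# ROMs are installed and that 'BreakoutNoFrameskip-v4' is a valid environment
# id for the installed Gym version.
if __name__ == '__main__':
    mdp = Atari('BreakoutNoFrameskip-v4', ends_at_life=True)
    state = mdp.reset()
    # Actions are wrapped in an array because Atari.step indexes action[0];
    # action 0 is NOOP, as asserted in the constructor.
    next_state, reward, absorbing, _ = mdp.step(np.array([0]))
    # LazyFrames converts to a (history_length, height, width) array on
    # demand.
    print(np.array(next_state).shape, reward, absorbing)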