-
Notifications
You must be signed in to change notification settings - Fork 437
/
Copy pathstella_environment.cpp
300 lines (247 loc) · 10.5 KB
/
stella_environment.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
/* *****************************************************************************
* A.L.E (Arcade Learning Environment)
* Copyright (c) 2009-2013 by Yavar Naddaf, Joel Veness, Marc G. Bellemare and
* the Reinforcement Learning and Artificial Intelligence Laboratory
* Released under the GNU General Public License; see License.txt for details.
*
* Based on: Stella -- "An Atari 2600 VCS Emulator"
* Copyright (c) 1995-2007 by Bradford W. Mott and the Stella team
*
* *****************************************************************************
* stella_environment.cpp
*
* A class that wraps around the Stella core to provide users with a typical
* reinforcement learning environment interface.
*
**************************************************************************** */
#include "environment/stella_environment.hpp"

#include <cassert>
#include <cstring>
#include <ctime>
#include <memory>
#include <optional>
#include <sstream>

#include "emucore/System.hxx"
namespace ale {
/** Builds an RL environment around an OSystem whose console is already
 *  loaded (its media source, properties and settings are read here).
 *
 *  @param osystem  emulator system; must have a console attached.
 *  @param settings ROM-specific settings (default mode, reward, terminal).
 */
StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings)
    : m_osystem(osystem),
      m_settings(settings),
      m_phosphor_blend(osystem),
      m_screen(m_osystem->console().mediaSource().height(),
               m_osystem->console().mediaSource().width()),
      m_player_a_action(PLAYER_A_NOOP),
      m_player_b_action(PLAYER_B_NOOP) {
  // Determine whether this is a paddle-based game
  m_use_paddles =
      m_osystem->console().properties().get(Controller_Left) == "PADDLES" ||
      m_osystem->console().properties().get(Controller_Right) == "PADDLES";
  if (m_use_paddles) {
    // A setting of -1 means "use the built-in paddle limits".
    int paddle_min_val = m_osystem->settings().getInt("paddle_min");
    int paddle_max_val = m_osystem->settings().getInt("paddle_max");
    m_state.setPaddleLimits(paddle_min_val != -1 ? paddle_min_val : PADDLE_MIN,
                            paddle_max_val != -1 ? paddle_max_val : PADDLE_MAX);
    m_state.resetPaddles(m_osystem->event());
  }

  m_num_reset_steps = 4;
  m_cartridge_md5 = m_osystem->console().properties().get(Cartridge_MD5);

  // Initialize RNG. A configured seed of -1 means "derive one from the wall
  // clock"; any other value must be non-negative.
  int32_t seed = m_osystem->settings().getInt("random_seed");
  if (seed == -1) {
    // time_t is wider than 32 bits on most platforms; truncate explicitly.
    seed = static_cast<int32_t>(std::time(nullptr));
  } else {
    assert(seed >= 0);
  }
  m_random.seed(static_cast<uint32_t>(seed));
  Logger::Info << "Random seed is " << seed << std::endl;

  // Set current mode to the ROM's default mode
  m_state.setCurrentMode(settings->getDefaultMode());

  m_max_num_frames_per_episode =
      m_osystem->settings().getInt("max_num_frames_per_episode");
  m_colour_averaging = m_osystem->settings().getBool("color_averaging");
  m_repeat_action_probability =
      m_osystem->settings().getFloat("repeat_action_probability");

  // Clamp in a signed local before storing, so a negative setting is caught
  // regardless of the member's declared type.
  int frame_skip = m_osystem->settings().getInt("frame_skip");
  if (frame_skip < 1) {
    Logger::Warning << "Warning: frame skip set to < 1. Setting to 1.\n";
    frame_skip = 1;
  }
  m_frame_skip = frame_skip;

  // If so desired, we record all emulated frames to a given directory
  const std::string recordDir =
      m_osystem->settings().getString("record_screen_dir");
  if (!recordDir.empty()) {
    Logger::Info << "Recording screens to directory: " << recordDir << "\n";
    // Create the screen exporter
    m_screen_exporter =
        std::make_unique<ScreenExporter>(m_osystem->colourPalette(), recordDir);
  }
}
/** Resets the system to its start state. */
void StellaEnvironment::reset() {
m_state.resetEpisodeFrameNumber();
// Reset the paddles
m_state.resetPaddles(m_osystem->event());
// Reset the emulator
m_osystem->console().system().reset();
// NOOP for 60 steps in the deterministic environment setting, or some random amount otherwise
int noopSteps;
noopSteps = 60;
emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, noopSteps);
// Reset the emulator
softReset();
// reset the rom (after emulating, in case the NOOPs led to reward)
m_settings->reset();
// Apply mode that was previously defined, then soft reset with this mode
m_settings->setMode(m_state.getCurrentMode(), m_osystem->console().system(),
getWrapper());
softReset();
// Apply necessary actions specified by the rom itself
ActionVect startingActions = m_settings->getStartingActions();
for (size_t i = 0; i < startingActions.size(); i++) {
emulate(startingActions[i], PLAYER_B_NOOP);
}
}
/** Snapshots the current environment state.
 *
 *  @param include_rng when true, the environment RNG is also handed to the
 *                     serializer; otherwise an empty optional is passed.
 *  @return the serialized ALEState.
 */
ALEState StellaEnvironment::cloneState(bool include_rng) {
  std::optional<Random*> rng_to_save;
  if (include_rng) {
    rng_to_save = &m_random;
  }
  return m_state.save(m_osystem, m_settings, rng_to_save, m_cartridge_md5);
}
void StellaEnvironment::restoreState(const ALEState& target_state) {
m_state.load(m_osystem, m_settings, &m_random, m_cartridge_md5, target_state);
}
/** Replaces, in place, any action that must not reach the emulator with the
 *  corresponding player's NOOP.
 *
 *  An action is dropped if either of these holds:
 *  - it lies in that player's joystick range and the ROM settings report it
 *    as not legal; or
 *  - it is RESET, which doesn't play nice with clean RL episode semantics.
 */
void StellaEnvironment::noopIllegalActions(Action& player_a_action,
                                           Action& player_b_action) {
  // Player A: joystick actions occupy the range below PLAYER_B_NOOP.
  const bool a_is_joystick =
      player_a_action < static_cast<Action>(PLAYER_B_NOOP);
  if ((a_is_joystick && !m_settings->isLegal(player_a_action)) ||
      player_a_action == RESET) {
    player_a_action = static_cast<Action>(PLAYER_A_NOOP);
  }

  // Player B: shift the action down by PLAYER_B_NOOP so the ROM's legality
  // check sees the equivalent player-A action.
  // NOTE(review): a value below PLAYER_B_NOOP passed here yields a negative
  // offset; callers presumably only pass player-B actions — confirm.
  const bool b_is_joystick = player_b_action < static_cast<Action>(RESET);
  if ((b_is_joystick &&
       !m_settings->isLegal(static_cast<Action>(
           static_cast<int>(player_b_action) - PLAYER_B_NOOP))) ||
      player_b_action == RESET) {
    player_b_action = static_cast<Action>(PLAYER_B_NOOP);
  }
}
/** Applies the given actions for m_frame_skip consecutive frames and returns
 *  the summed reward.
 *
 *  On every frame, each player independently keeps their previously applied
 *  action with probability m_repeat_action_probability ("sticky actions");
 *  otherwise the newly requested action is adopted.
 */
reward_t StellaEnvironment::act(Action player_a_action,
                                Action player_b_action) {
  // Total reward received as we repeat the action
  reward_t sum_rewards = 0;
  Random& rng = getEnvironmentRNG();
  // Apply the same action for a given number of times. Note that oneStepAct()
  // refuses to emulate past a terminal state, so frames after termination add
  // no reward (the screen is still rendered/exported on those iterations).
  for (size_t i = 0; i < m_frame_skip; i++) {
    // Stochastically drop actions, according to m_repeat_action_probability
    if (rng.nextDouble() >= m_repeat_action_probability)
      m_player_a_action = player_a_action;
    // TODO: possibly avoid drawing a random number when player B is "off".
    if (rng.nextDouble() >= m_repeat_action_probability)
      m_player_b_action = player_b_action;

    // If so desired, request one frame's worth of sound (this does nothing if
    // recording is not enabled)
    m_osystem->sound().recordNextFrame();
    // Render screen if we're displaying it
    m_osystem->screen().render();
    // Similarly record screen as needed
    if (m_screen_exporter)
      m_screen_exporter->saveNext(m_screen);

    // Use the stored actions, which may or may not have changed this frame
    sum_rewards += oneStepAct(m_player_a_action, m_player_b_action);
  }
  return sum_rewards;
}
/** Emulates a push on the console's reset button: RESET is held for
 *  m_num_reset_steps frames while player B idles. */
void StellaEnvironment::softReset() {
  emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps);

  // Clear the remembered actions so that sticky-action repetition in act()
  // starts from NOOP after the reset.
  m_player_b_action = PLAYER_B_NOOP;
  m_player_a_action = PLAYER_A_NOOP;
}
/** Applies the given actions (e.g. updating paddle positions when the paddle
 *  is used) and performs one simulation step in Stella.
 *
 *  @return the reward reported by the ROM settings for this step, or 0 if the
 *          environment was already terminal (in which case nothing is
 *          emulated and the frame counter is not advanced).
 */
reward_t StellaEnvironment::oneStepAct(Action player_a_action,
                                       Action player_b_action) {
  // Once in a terminal state, refuse to go any further. Special actions must
  // be handled outside of this environment; in particular, reset() should be
  // called rather than passing RESET or SYSTEM_RESET.
  if (isTerminal())
    return 0;

  // Convert illegal actions into NOOPs. RESET is also mapped to NOOP by
  // noopIllegalActions(), so a RESET passed in here never reaches the console.
  noopIllegalActions(player_a_action, player_b_action);

  // Emulate in the emulator
  emulate(player_a_action, player_b_action);
  // Increment the number of frames seen so far
  m_state.incrementFrame();

  return m_settings->getReward();
}
/** True when the ROM reports game over, or when the optional per-episode
 *  frame cap (enabled only if m_max_num_frames_per_episode > 0) is reached. */
bool StellaEnvironment::isTerminal() const {
  if (m_settings->isTerminal()) {
    return true;
  }
  const bool frame_cap_enabled = m_max_num_frames_per_episode > 0;
  return frame_cap_enabled &&
         m_state.getEpisodeFrameNumber() >= m_max_num_frames_per_episode;
}
/** Presses the console select switch, then advances the emulator.
 *
 *  The first num_steps frames only update the media source. They do not call
 *  m_settings->step() and are not counted in the episode frame number. A
 *  single NOOP frame follows, and it is counted.
 *  NOTE(review): whether select stays held during those frames depends on
 *  ALEState::pressSelect — confirm.
 */
void StellaEnvironment::pressSelect(size_t num_steps) {
  m_state.pressSelect(m_osystem->event());
  for (size_t remaining = num_steps; remaining > 0; --remaining) {
    m_osystem->console().mediaSource().update();
  }
  // Refresh the cached screen/RAM views before the trailing NOOP frame.
  processScreen();
  processRAM();
  emulate(PLAYER_A_NOOP, PLAYER_B_NOOP);
  m_state.incrementFrame();
}
/** Records the requested difficulty in the environment state (m_state). */
void StellaEnvironment::setDifficulty(difficulty_t value) {
  this->m_state.setDifficulty(value);
}
/** Records the requested game mode in the environment state. reset() reads
 *  this mode back and applies it to the ROM settings. */
void StellaEnvironment::setMode(game_mode_t value) {
  this->m_state.setCurrentMode(value);
}
/** Runs the emulator forward num_steps frames with the given actions. Each
 *  frame also steps the ROM settings, and at the end the cached screen and RAM
 *  are refreshed.
 *
 *  Paddle games must have their paddle positions updated on every frame. In
 *  joystick mode the action events only need to be set once, before the loop.
 */
void StellaEnvironment::emulate(Action player_a_action, Action player_b_action,
                                size_t num_steps) {
  Event* const event = m_osystem->event();
  const bool paddles = m_use_paddles;

  if (!paddles) {
    m_state.setActionJoysticks(event, player_a_action, player_b_action);
  }

  for (size_t frame = 0; frame < num_steps; ++frame) {
    if (paddles) {
      m_state.applyActionPaddles(event, player_a_action, player_b_action);
    }
    m_osystem->console().mediaSource().update();
    m_settings->step(m_osystem->console().system());
  }

  // Parse screen and RAM into their respective data structures
  processScreen();
  processRAM();
}
/** Accessor methods for the environment state. */
void StellaEnvironment::setState(const ALEState& state) { m_state = state; }
/** Read-only view of the current environment state. */
const ALEState& StellaEnvironment::getState() const {
  return this->m_state;
}
/** Returns a newly allocated wrapper constructed from *this.
 *  NOTE(review): the wrapper presumably keeps a reference to this
 *  environment, so it must not outlive it — confirm. */
std::unique_ptr<StellaEnvironmentWrapper> StellaEnvironment::getWrapper() {
  std::unique_ptr<StellaEnvironmentWrapper> wrapper(
      new StellaEnvironmentWrapper(*this));
  return wrapper;
}
/** Refreshes m_screen from the emulator. With colour averaging enabled, the
 *  phosphor blender writes its result into m_screen. Otherwise the current
 *  frame buffer is copied verbatim. */
void StellaEnvironment::processScreen() {
  if (m_colour_averaging) {
    m_phosphor_blend.process(m_screen);
    return;
  }
  const auto* frame = m_osystem->console().mediaSource().currentFrameBuffer();
  std::memcpy(m_screen.getArray(), frame, m_screen.arraySize());
}
void StellaEnvironment::processRAM() {
// Copy RAM over
for (size_t i = 0; i < m_ram.size(); i++)
*m_ram.byte(i) = m_osystem->console().system().peek(i + 0x80);
}
/** Writes `value` to console address 0x80 + memory_index and updates the
 *  local m_ram mirror to match.
 *  NOTE(review): memory_index is not bounds-checked here; callers presumably
 *  ensure memory_index < m_ram.size() — confirm. */
void StellaEnvironment::setRAM(size_t memory_index, byte_t value) {
  const size_t address = memory_index + 0x80;
  m_osystem->console().system().poke(address, value);
  *m_ram.byte(memory_index) = value;
}
} // namespace ale