Commit: mario experiments

DenSumy committed Jan 26, 2024
1 parent cb13436 commit f3b5f04

Showing 7 changed files with 102 additions and 89 deletions.
2 changes: 1 addition & 1 deletion rl_games/algos_torch/network_builder.py
@@ -654,7 +654,7 @@ def __init__(self, params, **kwargs):
rnn_in_size += actions_num

self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
-#self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

mlp_args = {
'input_size' : mlp_input_size,
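The change above enables the previously commented-out LayerNorm on the LSTM output. A minimal sketch of the idea, with illustrative sizes rather than the builder's actual ones:

    import torch
    import torch.nn as nn

    rnn_units = 256
    rnn = nn.LSTM(input_size=512, hidden_size=rnn_units, num_layers=1)
    layer_norm = nn.LayerNorm(rnn_units)

    x = torch.randn(4, 32, 512)   # (seq_len, batch, features)
    out, (h, c) = rnn(x)          # out: (seq_len, batch, rnn_units)
    out = layer_norm(out)         # normalize each timestep's feature vector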
3 changes: 3 additions & 0 deletions rl_games/common/ivecenv.py
@@ -34,3 +34,6 @@ def get_env_state(self):

def set_env_state(self, env_state):
pass

+def render(self, mode, **kwargs):
+pass
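The new no-op render gives every IVecEnv backend a uniform interface; concrete envs override it. A hedged sketch of a subclass (the class name and wrapped env are hypothetical):

    from rl_games.common.ivecenv import IVecEnv

    class SingleEnvVec(IVecEnv):  # hypothetical backend
        def __init__(self, env):
            self.env = env

        def render(self, mode, **kwargs):
            return self.env.render(mode, **kwargs)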
11 changes: 9 additions & 2 deletions rl_games/common/vecenv.py
@@ -46,8 +46,8 @@ def seed(self, seed):
random.seed(seed)
self.env.seed(seed)

-def render(self):
-self.env.render()
+def render(self, **kwargs):
+self.env.render(**kwargs)

def reset(self):
obs = self.env.reset()
@@ -72,6 +72,9 @@ def can_concat_infos(self):
else:
return False

+def render(self, mode, **kwargs):
+self.env.render(mode, **kwargs)

def get_env_info(self):
info = {}
observation_space = self.env.observation_space
@@ -215,6 +218,10 @@ def reset(self):
ret_obs = newobsdict
return ret_obs

+def render(self, mode, **kwargs):
+res = self.workers[0].render.remote(mode, **kwargs)
+return self.ray.get(res)

vecenv_config = {}

def register(config_name, func):
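RayVecEnv renders by querying only its first remote worker, so a rollout with many actors opens at most one viewer. A self-contained sketch of the pattern, with a stand-in Worker rather than rl_games' actual remote env actor:

    import ray

    @ray.remote
    class Worker:  # stand-in for the remote env actor
        def __init__(self):
            self.t = 0

        def render(self, mode):
            self.t += 1
            return f'frame {self.t} ({mode})'

    ray.init()
    workers = [Worker.remote() for _ in range(4)]
    # only worker 0 renders; ray.get blocks until the frame comes back
    print(ray.get(workers[0].render.remote('rgb_array')))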
74 changes: 32 additions & 42 deletions rl_games/common/wrappers.py
@@ -93,12 +93,14 @@ def step(self, action):
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
-if lives < self.lives and lives > 0:
+if lives < self.lives:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
+elif lives > self.lives:
+# do not allow use of bonus life
+self.lives = lives
return obs, reward, done, info

def reset(self, **kwargs):
Expand All @@ -115,20 +117,21 @@ def reset(self, **kwargs):
return obs

class EpisodicLifeMarioEnv(gym.Wrapper):
-def __init__(self, env):
+def __init__(self, env, max_lives):
"""Make end-of-life == end-of-episode, but only reset on True game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
+self.max_lives = max_lives

def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
-lives = self.env.unwrapped.env._life
+lives = self.env.unwrapped._life
if lives < self.lives:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
@@ -146,42 +149,7 @@ def reset(self, **kwargs):
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
-else:
-# no-op step to advance from terminal/lost life state
-obs, _, _, _ = self.env.step(0)
-self.lives = self.env.unwrapped.env._life
-return obs
-
-class EpisodicLifeMarioEnv(gym.Wrapper):
-def __init__(self, env):
-"""Make end-of-life == end-of-episode, but only reset on True game over.
-Done by DeepMind for the DQN and co. since it helps value estimation.
-"""
-gym.Wrapper.__init__(self, env)
-self.lives = 0
-self.was_real_done = True
-
-def step(self, action):
-obs, reward, done, info = self.env.step(action)
-self.was_real_done = done
-# check current lives, make loss of life terminal,
-# then update lives to handle bonus lives
-lives = self.env.unwrapped._life
-if lives < self.lives and lives > 0:
-# for Qbert sometimes we stay in lives == 0 condition for a few frames
-# so it's important to keep lives > 0, so that we only reset once
-# the environment advertises done.
-done = True
-self.lives = lives
-return obs, reward, done, info
-
-def reset(self, **kwargs):
-"""Reset only when lives are exhausted.
-This way all states are still reachable even though lives are episodic,
-and the learner need not know about any of this behind-the-scenes.
-"""
-if self.was_real_done:
-obs = self.env.reset(**kwargs)
+self.env.unwrapped.ram[0x075a] = self.max_lives
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
@@ -225,6 +193,28 @@ def reset(self, **kwargs):
self.lives = self.env.unwrapped.env._life
return obs

+class PreventSlugEnv(gym.Wrapper):
+def __init__(self, env, max_no_rewards=10000):
+"""Abort the episode if too much time passes without any reward."""
+gym.Wrapper.__init__(self, env)
+self.last_reward = 0
+self.steps = 0
+self.max_no_rewards = max_no_rewards
+self.got_reward = False
+
+def step(self, *args, **kwargs):
+obs, reward, done, info = self.env.step(*args, **kwargs)
+self.steps += 1
+if reward > 0:
+self.last_reward = self.steps
+if self.steps - self.last_reward > self.max_no_rewards:
+done = True
+return obs, reward, done, info
+
+def reset(self):
+self.got_reward = False
+self.steps = 0
+self.last_reward = 0  # start the no-reward counter fresh each episode
+return self.env.reset()

class EpisodeStackedEnv(gym.Wrapper):
def __init__(self, env):
@@ -791,14 +781,14 @@ def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directo
return env


-def wrap_deepmind(env, episode_life=False, clip_rewards=True, frame_stack=True, scale=False, wrap_impala=False):
+def wrap_deepmind(env, episode_life=False, clip_rewards=True, frame_stack=True, scale=False, wrap_impala=False, gray=True):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
-env = WarpFrame(env)
+env = WarpFrame(env, grayscale=gray)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
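Two additions above are worth unpacking: EpisodicLifeMarioEnv now refills the lives counter by writing max_lives into NES RAM address 0x075a (the lives slot in the Super Mario Bros RAM map) on every real reset, and PreventSlugEnv truncates episodes that run too long without reward. A hedged usage sketch (wrapper arguments are illustrative):

    import gym_super_mario_bros
    from nes_py.wrappers import JoypadSpace
    from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
    from rl_games.common import wrappers

    env = gym_super_mario_bros.make('SuperMarioBros-v1')
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    env = wrappers.EpisodicLifeMarioEnv(env, 16)             # refills ram[0x075a] on real resets
    env = wrappers.PreventSlugEnv(env, max_no_rewards=10000)  # give up after 10k reward-free steps

    obs = env.reset()
    print(env.unwrapped.ram[0x075a])  # should now read 16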
41 changes: 20 additions & 21 deletions rl_games/configs/mario/mario_resnet.yaml
@@ -20,55 +20,54 @@ params:
activation: relu
initializer:
name: default
-regularizer:
-name: 'None'


mlp:
units: [512]
activation: relu
-regularizer:
-name: 'None'
initializer:
-name: default
+name: orthogonal_initializer
+gain: 1.41421356237
rnn:
name: lstm
units: 256
layers: 1
+#layer_norm: True
config:
reward_shaper:
-min_val: -1
-max_val: 1
+scale_value: 1

normalize_advantage: True
gamma: 0.995
tau: 0.95
-learning_rate: 3e-4
+learning_rate: 5e-4
name: mario_resnet
score_to_win: 100000
-grad_norm: 1.5
-entropy_coef: 0.01
+grad_norm: 1.0
+entropy_coef: 0.005
truncate_grads: True
env_name: MarioEnv
e_clip: 0.2
clip_value: True
num_actors: 16
-horizon_length: 256
-minibatch_size: 2048
-mini_epochs: 2
-critic_coef: 1
-lr_schedule: None
-kl_threshold: 0.01
+horizon_length: 512
+minibatch_size: 4096
+mini_epochs: 3
+critic_coef: 2
+lr_schedule: None #adaptive
+kl_threshold: 0.008
normalize_input: False
normalize_value: True
use_diagnostics: True
seq_length: 32
max_epochs: 200000

weight_decay: 0.0001
save_frequency: 50
env_config:
use_dict_obs_space: True

player:
-render: False
+render: True
games_num: 20
n_game_life: 5
-deterministic: True
+deterministic: False
+use_vecenv: True
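A hedged sketch of launching this config through rl_games' Runner (the entry-point API is the repo's usual one; treat the exact run flags as an assumption):

    import yaml
    from rl_games.torch_runner import Runner

    with open('rl_games/configs/mario/mario_resnet.yaml') as f:
        config = yaml.safe_load(f)

    runner = Runner()
    runner.load(config)                         # builds algo and env from 'params'
    runner.run({'train': True, 'play': False})  # or {'play': True} to evaluate a checkpoint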
46 changes: 29 additions & 17 deletions rl_games/configs/mario/mario_v1_random.yaml
@@ -8,33 +8,45 @@ params:
name: discrete_a2c

network:
-name: resnet_actor_critic
-require_rewards: False
-require_last_actions: False
+name: actor_critic
separate: False
value_shape: 1
space:
discrete:

cnn:
-conv_depths: [32, 64, 128, 256]
-activation: relu
+#permute_input: False
+type: conv2d
+activation: elu
initializer:
name: default
+#name: glorot_normal_initializer
+#gain: 1.4142
regularizer:
-name: 'None'
+name: None
+convs:
+- filters: 32
+kernel_size: 8
+strides: 4
+padding: 0
+- filters: 64
+kernel_size: 4
+strides: 2
+padding: 0
+- filters: 64
+kernel_size: 3
+strides: 1
+padding: 0

mlp:
units: [512]
-activation: relu
-regularizer:
-name: 'None'
+activation: elu
initializer:
-name: default
+name: orthogonal_initializer
+gain: 1.41421356237

config:
name: mario_ray
-env_name: 'SuperMarioBrosRandomStages-v1'
+env_name: MarioEnv
score_to_win: 100500
normalize_value: True
normalize_input: False
@@ -53,8 +65,8 @@ params:
entropy_coef: 0.01
e_clip: 0.2
clip_value: False
-num_actors: 64
-horizon_length: 128
+num_actors: 16
+horizon_length: 512
# seq_length: 8
minibatch_size: 4096
mini_epochs: 4
Expand All @@ -65,8 +77,8 @@ params:

player:
render: True
-games_num: 1
-n_game_life: 1
+games_num: 2
+n_game_life: 16
deterministic: False
use_vecenv: False
render_sleep: 0.05
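A quick sanity check of the PPO batch geometry these values imply (informal arithmetic mirroring how rl_games slices rollouts):

    num_actors, horizon_length = 16, 512
    minibatch_size, mini_epochs = 4096, 4

    batch_size = num_actors * horizon_length   # 8192 transitions per rollout
    assert batch_size % minibatch_size == 0    # rollout must split evenly into minibatches
    grad_steps = (batch_size // minibatch_size) * mini_epochs
    print(grad_steps)                          # 8 gradient steps per policy update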
14 changes: 8 additions & 6 deletions rl_games/envs/mario.py
@@ -5,7 +5,6 @@
class MarioEnv(gym.Env):
def __init__(self, **kwargs):
env_name=kwargs.pop('env_name', 'SuperMarioBros-v1')
-self.has_lives = kwargs.pop('has_lives', True)
self.max_lives = kwargs.pop('max_lives', 16)
self.movement = kwargs.pop('movement', 'SIMPLE')
self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', False)
@@ -26,31 +25,31 @@ def _create_super_mario_env(self, name='SuperMarioBros-v1'):
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
import gym_super_mario_bros
-movement = SIMPLE_MOVEMENT if name == 'SIMPLE' else COMPLEX_MOVEMENT
+movement = SIMPLE_MOVEMENT if self.movement == 'SIMPLE' else COMPLEX_MOVEMENT
env = gym_super_mario_bros.make(name)
env = JoypadSpace(env, movement)
if 'Random' in name:
env = wrappers.EpisodicLifeRandomMarioEnv(env)
else:
-env = wrappers.EpisodicLifeMarioEnv(env)
+env = wrappers.EpisodicLifeMarioEnv(env, self.max_lives)
env = wrappers.MaxAndSkipEnv(env, skip=4)
env = wrappers.wrap_deepmind(
-env, episode_life=False, clip_rewards=False, frame_stack=True, scale=True)
+env, episode_life=False, clip_rewards=False, frame_stack=True, scale=True, gray=False)
return env

def step(self, action):
next_obs, reward, is_done, info = self.env.step(action)
if self.use_dict_obs_space:
next_obs = {
'observation': next_obs,
-'reward': np.clip(np.array(reward, dtype=float), -1, 1),
+'reward': np.array(reward, dtype=float),
'last_action': np.array(action, dtype=int)
}
return next_obs, reward, is_done, info

def reset(self):
obs = self.env.reset()
-self.env.unwrapped.ram[0x075a] = self.max_lives

if self.use_dict_obs_space:
obs = {
'observation': obs,
Expand All @@ -59,5 +58,8 @@ def reset(self):
}
return obs

+def render(self, mode, **kwargs):
+self.env.render(mode, **kwargs)

def get_number_of_agents(self):
return 1
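Taken together, MarioEnv now returns the raw (unclipped) reward in the dict observation, delegates the lives refill to EpisodicLifeMarioEnv, and exposes render. A hedged usage sketch (assumes gym_super_mario_bros and nes_py are installed, and that MarioEnv mirrors the wrapped env's action_space as gym envs typically do):

    from rl_games.envs.mario import MarioEnv

    env = MarioEnv(env_name='SuperMarioBros-v1', movement='SIMPLE',
                   max_lives=16, use_dict_obs_space=True)
    obs = env.reset()                       # dict: observation / reward / last_action
    obs, reward, done, info = env.step(env.action_space.sample())
    print(obs['reward'], obs['last_action'])
    env.render('human')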
