Commit

mario petuh
DenSumy committed Dec 31, 2023
1 parent c153721 commit cb13436
Showing 7 changed files with 159 additions and 4 deletions.
1 change: 1 addition & 0 deletions rl_games/algos_torch/network_builder.py
@@ -740,6 +740,7 @@ def forward(self, obs_dict):
            obs_list.append(reward.unsqueeze(1))
        if self.require_last_actions:
            obs_list.append(last_action)

        out = torch.cat(obs_list, dim=1)
        batch_size = out.size()[0]
        num_seqs = batch_size // seq_length
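For context, a small shape sketch of what this concatenation produces (illustrative only; the sizes B, F, A and the encoding of last_action are assumptions, not taken from the commit):

import torch

B, F, A, seq_length = 8, 16, 7, 4   # assumed batch/feature/action sizes
obs = torch.randn(B, F)             # stand-in for the processed observation
reward = torch.randn(B)             # per-step reward, shape (B,)
last_action = torch.randn(B, A)     # stand-in for an encoded last action

out = torch.cat([obs, reward.unsqueeze(1), last_action], dim=1)
assert out.shape == (B, F + 1 + A)  # reward adds a single extra column
num_seqs = B // seq_length          # here: 2 RNN sequences of length 4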
8 changes: 8 additions & 0 deletions rl_games/common/env_configurations.py
@@ -276,6 +276,10 @@ def create_env(name, **kwargs):
    env = wrappers.TimeLimit(env, steps_limit)
    return env

def create_mario_env(**kwargs):
    import gym
    import rl_games.envs.mario
    return gym.make('MarioEnv-v0', **kwargs)

configurations = {
    'CartPole-v1': {
@@ -450,6 +454,10 @@ def create_env(name, **kwargs):
        'env_creator': lambda **kwargs: create_cule(**kwargs),
        'vecenv_type': 'CULE'
    },
    'MarioEnv': {
        'env_creator': lambda **kwargs: create_mario_env(**kwargs),
        'vecenv_type': 'RAY'
    },
}


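A hedged usage sketch of the new registry entry (assumes gym_super_mario_bros and nes_py are installed, since the env imports them on creation):

from rl_games.common.env_configurations import configurations

creator = configurations['MarioEnv']['env_creator']
env = creator(env_name='SuperMarioBros-v1', use_dict_obs_space=True)
obs = env.reset()   # dict with 'observation', 'reward', 'last_action'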
2 changes: 1 addition & 1 deletion rl_games/common/vecenv.py
@@ -156,7 +156,7 @@ def step(self, actions):
            ret_obs = dicts_to_dict_with_arrays(newobs, self.num_agents == 1)
        else:
            ret_obs = self.concat_func(newobs)

        if self.use_global_obs:
            newobsdict = {}
            newobsdict["obs"] = ret_obs
6 changes: 4 additions & 2 deletions rl_games/common/wrappers.py
@@ -129,12 +129,14 @@ def step(self, action):
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.env._life
        if lives < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
            self.lives = lives
        elif lives > self.lives:
            # do not allow use of bonus life
            self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
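A hypothetical walk-through of the two branches above, with an assumed sequence of life counts (starting from 2):

lives_seq = [2, 1, 1, 2, 0]   # bonus life at step 4, deaths at steps 2 and 5
tracked, dones = 2, []
for lives in lives_seq:
    done = lives < tracked    # any drop in lives ends the episode
    # lives > tracked (a bonus life) only updates the counter, as in the elif
    tracked = lives
    dones.append(done)
assert dones == [False, True, False, False, True]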
74 changes: 74 additions & 0 deletions rl_games/configs/mario/mario_resnet.yaml
@@ -0,0 +1,74 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: discrete_a2c

  network:
    name: resnet_actor_critic
    require_rewards: True
    require_last_actions: True
    separate: False
    value_shape: 1
    space:
      discrete:

    cnn:
      permute_input: True
      conv_depths: [16, 32, 32]
      activation: relu
      initializer:
        name: default
      regularizer:
        name: 'None'

    mlp:
      units: [512]
      activation: relu
      regularizer:
        name: 'None'
      initializer:
        name: default
    rnn:
      name: lstm
      units: 256
      layers: 1
  config:
    reward_shaper:
      min_val: -1
      max_val: 1

    normalize_advantage: True
    gamma: 0.995
    tau: 0.95
    learning_rate: 3e-4
    name: mario_resnet
    score_to_win: 100000
    grad_norm: 1.5
    entropy_coef: 0.01
    truncate_grads: True
    env_name: MarioEnv
    e_clip: 0.2
    clip_value: True
    num_actors: 16
    horizon_length: 256
    minibatch_size: 2048
    mini_epochs: 2
    critic_coef: 1
    lr_schedule: None
    kl_threshold: 0.01
    normalize_input: False
    use_diagnostics: True
    seq_length: 32
    max_epochs: 200000

    env_config:
      use_dict_obs_space: True

    player:
      render: False
      games_num: 20
      n_game_life: 5
      deterministic: True
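A minimal launch sketch for this config, assuming the standard rl_games Runner entry point:

import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/mario/mario_resnet.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)                 # expects the full dict with the 'params' key
runner.run({'train': True, 'play': False})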

9 changes: 8 additions & 1 deletion rl_games/envs/__init__.py
@@ -3,4 +3,11 @@
from rl_games.envs.test_network import TestNetBuilder
from rl_games.algos_torch import model_builder

model_builder.register_network('testnet', TestNetBuilder)

import gym

gym.envs.register(
    id='MarioEnv-v0',
    entry_point='rl_games.envs.mario:MarioEnv'
)
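Once rl_games.envs is imported, the id resolves through gym's registry; a quick check (again assuming the Mario dependencies are installed):

import gym
import rl_games.envs   # runs the gym.envs.register call above

env = gym.make('MarioEnv-v0', use_dict_obs_space=True)
print(env.observation_space)   # Dict with observation/reward/last_action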
63 changes: 63 additions & 0 deletions rl_games/envs/mario.py
@@ -0,0 +1,63 @@
import gym
import numpy as np
from rl_games.common import wrappers

class MarioEnv(gym.Env):
    def __init__(self, **kwargs):
        env_name = kwargs.pop('env_name', 'SuperMarioBros-v1')
        self.has_lives = kwargs.pop('has_lives', True)
        self.max_lives = kwargs.pop('max_lives', 16)
        self.movement = kwargs.pop('movement', 'SIMPLE')
        self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', False)
        self.env = self._create_super_mario_env(env_name)
        if self.use_dict_obs_space:
            self.observation_space = gym.spaces.Dict({
                'observation': self.env.observation_space,
                # step() clips rewards to [-1, 1], so the space matches that range
                'reward': gym.spaces.Box(low=-1, high=1, shape=(), dtype=np.float32),
                'last_action': gym.spaces.Box(low=0, high=self.env.action_space.n - 1, shape=(), dtype=int)
            })
        else:
            self.observation_space = self.env.observation_space

        self.action_space = self.env.action_space

    def _create_super_mario_env(self, name='SuperMarioBros-v1'):
        from nes_py.wrappers import JoypadSpace
        from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
        import gym_super_mario_bros
        # select the action set from the configured movement mode, not the env name
        movement = SIMPLE_MOVEMENT if self.movement == 'SIMPLE' else COMPLEX_MOVEMENT
        env = gym_super_mario_bros.make(name)
        env = JoypadSpace(env, movement)
        if 'Random' in name:
            env = wrappers.EpisodicLifeRandomMarioEnv(env)
        else:
            env = wrappers.EpisodicLifeMarioEnv(env)
        env = wrappers.MaxAndSkipEnv(env, skip=4)
        env = wrappers.wrap_deepmind(
            env, episode_life=False, clip_rewards=False, frame_stack=True, scale=True)
        return env

    def step(self, action):
        next_obs, reward, is_done, info = self.env.step(action)
        if self.use_dict_obs_space:
            next_obs = {
                'observation': next_obs,
                'reward': np.clip(np.array(reward, dtype=float), -1, 1),
                'last_action': np.array(action, dtype=int)
            }
        return next_obs, reward, is_done, info

    def reset(self):
        obs = self.env.reset()
        # 0x075A holds the life counter in the Super Mario Bros RAM map
        self.env.unwrapped.ram[0x075a] = self.max_lives
        if self.use_dict_obs_space:
            obs = {
                'observation': obs,
                'reward': np.array(0.0, dtype=float),
                'last_action': np.array(0, dtype=int),
            }
        return obs

    def get_number_of_agents(self):
        return 1
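A short smoke-test sketch for the dict-observation mode (illustrative; the 10-step loop is an arbitrary choice):

from rl_games.envs.mario import MarioEnv

env = MarioEnv(use_dict_obs_space=True)
obs = env.reset()
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    assert set(obs) == {'observation', 'reward', 'last_action'}
    if done:
        obs = env.reset()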
