From cb13436c8a9212b63975f8be0b8e16f96db06783 Mon Sep 17 00:00:00 2001
From: Denys Makoviichuk
Date: Sun, 31 Dec 2023 17:10:05 -0500
Subject: [PATCH] Add Super Mario Bros environment and ResNet LSTM config

---
 rl_games/algos_torch/network_builder.py  |  1 +
 rl_games/common/env_configurations.py    |  8 +++
 rl_games/common/vecenv.py                |  2 +-
 rl_games/common/wrappers.py              |  6 +-
 rl_games/configs/mario/mario_resnet.yaml | 74 ++++++++++++++++++++++++
 rl_games/envs/__init__.py                |  9 ++-
 rl_games/envs/mario.py                   | 63 ++++++++++++++++++++
 7 files changed, 159 insertions(+), 4 deletions(-)
 create mode 100644 rl_games/configs/mario/mario_resnet.yaml
 create mode 100644 rl_games/envs/mario.py

diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index ab047920..0eba37b0 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -740,6 +740,7 @@ def forward(self, obs_dict):
                 obs_list.append(reward.unsqueeze(1))
             if self.require_last_actions:
                 obs_list.append(last_action)
+
             out = torch.cat(obs_list, dim=1)
             batch_size = out.size()[0]
             num_seqs = batch_size // seq_length
diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py
index 8e6f00c7..95e59be2 100644
--- a/rl_games/common/env_configurations.py
+++ b/rl_games/common/env_configurations.py
@@ -276,6 +276,10 @@ def create_env(name, **kwargs):
         env = wrappers.TimeLimit(env, steps_limit)
     return env
 
+def create_mario_env(**kwargs):
+    import gym
+    import rl_games.envs.mario
+    return gym.make('MarioEnv-v0', **kwargs)
 
 configurations = {
     'CartPole-v1': {
@@ -450,6 +454,10 @@ def create_env(name, **kwargs):
         'env_creator': lambda **kwargs: create_cule(**kwargs),
         'vecenv_type': 'CULE'
     },
+    'MarioEnv': {
+        'env_creator': lambda **kwargs: create_mario_env(**kwargs),
+        'vecenv_type': 'RAY'
+    },
 }
 
 
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index 01016723..66255ac9 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -156,7 +156,7 @@ def step(self, actions):
             ret_obs = dicts_to_dict_with_arrays(newobs, self.num_agents == 1)
         else:
             ret_obs = self.concat_func(newobs)
-        
+
         if self.use_global_obs:
             newobsdict = {}
             newobsdict["obs"] = ret_obs
diff --git a/rl_games/common/wrappers.py b/rl_games/common/wrappers.py
index b1262a17..2e9355d6 100644
--- a/rl_games/common/wrappers.py
+++ b/rl_games/common/wrappers.py
@@ -129,12 +129,14 @@ def step(self, action):
         # check current lives, make loss of life terminal,
         # then update lives to handle bonus lives
         lives = self.env.unwrapped.env._life
-        if lives < self.lives and lives > 0:
+        if lives < self.lives:
             # for Qbert sometimes we stay in lives == 0 condition for a few frames
             # so it's important to keep lives > 0, so that we only reset once
             # the environment advertises done.
             done = True
-            self.lives = lives
+        elif lives > self.lives:
+            # do not allow use of bonus life
+            self.lives = lives
         return obs, reward, done, info
 
     def reset(self, **kwargs):
diff --git a/rl_games/configs/mario/mario_resnet.yaml b/rl_games/configs/mario/mario_resnet.yaml
new file mode 100644
index 00000000..6ad79a54
--- /dev/null
+++ b/rl_games/configs/mario/mario_resnet.yaml
@@ -0,0 +1,74 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: discrete_a2c
+
+  network:
+    name: resnet_actor_critic
+    require_rewards: True
+    require_last_actions: True
+    separate: False
+    value_shape: 1
+    space:
+      discrete:
+
+    cnn:
+      permute_input: True
+      conv_depths: [16, 32, 32]
+      activation: relu
+      initializer:
+        name: default
+      regularizer:
+        name: 'None'
+
+    mlp:
+      units: [512]
+      activation: relu
+      regularizer:
+        name: 'None'
+      initializer:
+        name: default
+    rnn:
+      name: lstm
+      units: 256
+      layers: 1
+  config:
+    reward_shaper:
+      min_val: -1
+      max_val: 1
+
+    normalize_advantage: True
+    gamma: 0.995
+    tau: 0.95
+    learning_rate: 3e-4
+    name: mario_resnet
+    score_to_win: 100000
+    grad_norm: 1.5
+    entropy_coef: 0.01
+    truncate_grads: True
+    env_name: MarioEnv
+    e_clip: 0.2
+    clip_value: True
+    num_actors: 16
+    horizon_length: 256
+    minibatch_size: 2048
+    mini_epochs: 2
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.01
+    normalize_input: False
+    use_diagnostics: True
+    seq_length: 32
+    max_epochs: 200000
+
+    env_config:
+      use_dict_obs_space: True
+
+    player:
+      render: False
+      games_num: 20
+      n_game_life: 5
+      deterministic: True
+
diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py
index 6883b34a..3623831b 100644
--- a/rl_games/envs/__init__.py
+++ b/rl_games/envs/__init__.py
@@ -3,4 +3,11 @@
 from rl_games.envs.test_network import TestNetBuilder
 from rl_games.algos_torch import model_builder
 
-model_builder.register_network('testnet', TestNetBuilder)
\ No newline at end of file
+model_builder.register_network('testnet', TestNetBuilder)
+
+import gym
+
+gym.envs.register(
+    id='MarioEnv-v0',
+    entry_point='rl_games.envs.mario:MarioEnv'
+)
\ No newline at end of file
diff --git a/rl_games/envs/mario.py b/rl_games/envs/mario.py
new file mode 100644
index 00000000..ddf84d6e
--- /dev/null
+++ b/rl_games/envs/mario.py
@@ -0,0 +1,63 @@
+import gym
+import numpy as np
+from rl_games.common import wrappers
+
+class MarioEnv(gym.Env):
+    def __init__(self, **kwargs):
+        env_name = kwargs.pop('env_name', 'SuperMarioBros-v1')
+        self.has_lives = kwargs.pop('has_lives', True)
+        self.max_lives = kwargs.pop('max_lives', 16)
+        self.movement = kwargs.pop('movement', 'SIMPLE')
+        self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', False)
+        self.env = self._create_super_mario_env(env_name)
+        if self.use_dict_obs_space:
+            self.observation_space = gym.spaces.Dict({
+                'observation': self.env.observation_space,
+                'reward': gym.spaces.Box(low=-1, high=1, shape=(), dtype=np.float32),
+                'last_action': gym.spaces.Box(low=0, high=self.env.action_space.n, shape=(), dtype=int)
+            })
+        else:
+            self.observation_space = self.env.observation_space
+
+        self.action_space = self.env.action_space
+
+
+    def _create_super_mario_env(self, name='SuperMarioBros-v1'):
+        from nes_py.wrappers import JoypadSpace
+        from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
+        import gym_super_mario_bros
+        movement = SIMPLE_MOVEMENT if self.movement == 'SIMPLE' else COMPLEX_MOVEMENT
+        env = gym_super_mario_bros.make(name)
+        env = JoypadSpace(env, movement)
+        if 'Random' in name:
+            env = wrappers.EpisodicLifeRandomMarioEnv(env)
+        else:
+            env = wrappers.EpisodicLifeMarioEnv(env)
+        env = wrappers.MaxAndSkipEnv(env, skip=4)
+        env = wrappers.wrap_deepmind(
+            env, episode_life=False, clip_rewards=False, frame_stack=True, scale=True)
+        return env
+
+    def step(self, action):
+        next_obs, reward, is_done, info = self.env.step(action)
+        if self.use_dict_obs_space:
+            next_obs = {
+                'observation': next_obs,
+                'reward': np.clip(np.array(reward, dtype=float), -1, 1),
+                'last_action': np.array(action, dtype=int)
+            }
+        return next_obs, reward, is_done, info
+
+    def reset(self):
+        obs = self.env.reset()
+        self.env.unwrapped.ram[0x075a] = self.max_lives
+        if self.use_dict_obs_space:
+            obs = {
+                'observation': obs,
+                'reward': np.array(0.0, dtype=float),
+                'last_action': np.array(0, dtype=int),
+            }
+        return obs
+
+    def get_number_of_agents(self):
+        return 1
\ No newline at end of file
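
Note: the snippet below is a minimal sketch, not part of the patch itself, showing how the MarioEnv-v0 registration added in rl_games/envs/__init__.py could be smoke-tested. It assumes gym-super-mario-bros and nes-py are installed with a gym version they support, and uses the use_dict_obs_space kwarg consumed by MarioEnv.__init__ above.

    # Minimal sketch: exercise the MarioEnv-v0 registration from rl_games/envs/__init__.py.
    # Assumes gym-super-mario-bros and nes-py are installed; extra kwargs to gym.make are
    # forwarded to MarioEnv (same path create_mario_env takes in env_configurations.py).
    import gym
    import rl_games.envs  # importing the package registers 'MarioEnv-v0'

    env = gym.make('MarioEnv-v0', use_dict_obs_space=True)
    obs = env.reset()
    # With use_dict_obs_space=True the observation is a dict with 'observation',
    # 'reward' and 'last_action' entries, matching the require_rewards /
    # require_last_actions flags of resnet_actor_critic in mario_resnet.yaml.
    for _ in range(100):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
    env.close()

Training itself would then go through the usual rl_games entry point, e.g. python runner.py --train --file rl_games/configs/mario/mario_resnet.yaml.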