diff --git a/pyproject.toml b/pyproject.toml
index a9213554..32fa2d09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,11 @@ envs = [
     "gymnasium[accept-rom-license]",
     "dm-control",
     "procgen",
-    "minigrid"
+    "minigrid",
+    "gym==0.26.2",
+    "gym-super-mario-bros==7.4.0",
+    "opencv-python==4.8.1.78",
+    "imageio==2.33.0",
 ]
 docs = [
     "mkdocs-material",
diff --git a/rllte/common/prototype/on_policy_agent.py b/rllte/common/prototype/on_policy_agent.py
index 2a77f740..cba48a42 100644
--- a/rllte/common/prototype/on_policy_agent.py
+++ b/rllte/common/prototype/on_policy_agent.py
@@ -144,8 +144,9 @@ def train(
             with th.no_grad():
                 last_values = self.policy.get_value(next_obs).detach()

-            # perform return and advantage estimation
-            self.storage.compute_returns_and_advantages(last_values)
+            # perform return and advantage estimation only if we have access to extrinsic rewards
+            if not self.pretraining:
+                self.storage.compute_returns_and_advantages(last_values)

             # deal with the intrinsic reward module
             if self.irs is not None:
@@ -157,6 +158,7 @@
                         "obs": self.storage.observations[:-1],  # type: ignore
                         "actions": self.storage.actions,
                         "next_obs": self.storage.observations[1:],  # type: ignore
+                        "done": th.logical_or(self.storage.terminateds[:-1], self.storage.truncateds[:-1])  # type: ignore
                     }
                 )
                 # compute intrinsic rewards
@@ -168,9 +170,16 @@
                     },
                     step=self.global_episode * self.num_envs * self.num_steps,
                 )
-                # only add the intrinsic rewards to the advantages and returns
-                self.storage.advantages += intrinsic_rewards.to(self.device)
-                self.storage.returns += intrinsic_rewards.to(self.device)
+
+                # if pretraining, compute returns and advantages from the intrinsic rewards alone
+                if self.pretraining:
+                    self.storage.rewards = intrinsic_rewards.to(self.device)
+                    self.storage.compute_returns_and_advantages(last_values)
+
+                # otherwise, add the intrinsic rewards to the extrinsic returns and advantages
+                else:
+                    self.storage.advantages += intrinsic_rewards.to(self.device)
+                    self.storage.returns += intrinsic_rewards.to(self.device)

             # update the agent
             self.update()
diff --git a/rllte/env/__init__.py b/rllte/env/__init__.py
index 1cf0e2e8..142da333 100644
--- a/rllte/env/__init__.py
+++ b/rllte/env/__init__.py
@@ -57,3 +57,8 @@
     from .procgen import make_procgen_env as make_procgen_env
 except Exception:
     pass
+
+try:
+    from .mario import make_mario_env as make_mario_env
+except Exception:
+    pass
diff --git a/rllte/env/mario/__init__.py b/rllte/env/mario/__init__.py
new file mode 100644
index 00000000..2166f722
--- /dev/null
+++ b/rllte/env/mario/__init__.py
@@ -0,0 +1,48 @@
+from typing import Callable
+
+import gymnasium as gym
+import gym as gym_old
+from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
+from gymnasium.wrappers import RecordEpisodeStatistics
+
+from nes_py.wrappers import JoypadSpace
+from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
+
+from rllte.env.utils import Gymnasium2Torch
+from rllte.env.mario.wrappers import (
+    EpisodicLifeEnv,
+    SkipFrame,
+    Gym2Gymnasium,
+    ImageTranspose,
+)
+
+
+def make_mario_env(
+    env_id: str = "SuperMarioBros-v0",
+    num_envs: int = 8,
+    device: str = "cpu",
+    asynchronous: bool = True,
+    seed: int = 0,
+) -> Gymnasium2Torch:
+    """Create SuperMarioBros environments with standard preprocessing wrappers.
+
+    Args:
+        env_id (str): Name of environment.
+        num_envs (int): Number of environments.
+        device (str): Device to convert the data.
+        asynchronous (bool): `True` for `AsyncVectorEnv` and `False` for `SyncVectorEnv`.
+        seed (int): Random seed.
+
+    Returns:
+        The vectorized environments.
+    """
+
+    def make_env(env_id: str, seed: int) -> Callable:
+        def _thunk():
+            env = gym_old.make(env_id, apply_api_compatibility=True, render_mode="rgb_array")
+            env = JoypadSpace(env, SIMPLE_MOVEMENT)
+            env = Gym2Gymnasium(env)
+            env = SkipFrame(env, skip=4)
+            env = gym.wrappers.ResizeObservation(env, (84, 84))
+            env = ImageTranspose(env)
+            env = EpisodicLifeEnv(env)
+            env.observation_space.seed(seed)
+            return env
+        return _thunk
+
+    envs = [make_env(env_id, seed + i) for i in range(num_envs)]
+    if asynchronous:
+        envs = AsyncVectorEnv(envs)
+    else:
+        envs = SyncVectorEnv(envs)
+
+    envs = RecordEpisodeStatistics(envs)
+    return Gymnasium2Torch(envs, device=device)
\ No newline at end of file
diff --git a/rllte/env/mario/wrappers.py b/rllte/env/mario/wrappers.py
new file mode 100644
index 00000000..1662aefc
--- /dev/null
+++ b/rllte/env/mario/wrappers.py
@@ -0,0 +1,107 @@
+import gymnasium as gym
+import numpy as np
+
+
+class EpisodicLifeEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Make end-of-life == end-of-episode, but only reset on true game
+        over.
+        """
+        gym.Wrapper.__init__(self, env)
+        self.lives = 0
+        self.was_real_done = True
+        self.env = env
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        self.was_real_done = np.logical_or(terminated, truncated)
+        try:
+            # treat the loss of a life as the end of an episode
+            lives = self.env.unwrapped.env._life
+            if self.lives > lives > 0:
+                terminated, truncated = True, True
+            self.lives = lives
+        except AttributeError:
+            pass
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+
+class SkipFrame(gym.Wrapper):
+    def __init__(self, env, skip):
+        """Return only every `skip`-th frame."""
+        super().__init__(env)
+        self._skip = skip
+        self.env = env
+
+    def step(self, action):
+        """Repeat the action and sum the rewards."""
+        total_reward = 0.0
+        for _ in range(self._skip):
+            # accumulate reward and repeat the same action
+            obs, reward, terminated, truncated, info = self.env.step(action)
+            total_reward += reward
+            if np.logical_or(terminated, truncated):
+                break
+        return obs, total_reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        return self.env.reset()
+
+    def render(self):
+        return self.env.render()
+
+
+class Gym2Gymnasium(gym.Wrapper):
+    def __init__(self, env):
+        """Convert a gym.Env to a gymnasium.Env."""
+        self.env = env
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=env.observation_space.shape,
+            dtype=env.observation_space.dtype,
+        )
+        self.action_space = gym.spaces.Discrete(env.action_space.n)
+
+    def step(self, action):
+        """Step the wrapped environment."""
+        return self.env.step(action)
+
+    def reset(self, options=None, seed=None):
+        return self.env.reset()
+
+    def render(self):
+        return self.env.render()
+
+    def close(self):
+        return self.env.close()
+
+    def seed(self, seed=None):
+        return self.env.seed(seed=seed)
+
+
+class ImageTranspose(gym.ObservationWrapper):
+    """Transpose observations from channels-last to channels-first.
+
+    Args:
+        env (gym.Env): Environment to wrap.
+
+    Returns:
+        ImageTranspose instance.
+    """
+
+    def __init__(self, env: gym.Env) -> None:
+        gym.ObservationWrapper.__init__(self, env)
+        shape = env.observation_space.shape
+        dtype = env.observation_space.dtype
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(shape[2], shape[0], shape[1]),
+            dtype=dtype,
+        )
+
+    def observation(self, observation):
+        """Transpose the observation to channels-first format."""
+        observation = np.transpose(observation, axes=[2, 0, 1])
+        return observation
diff --git a/rllte/xplore/reward/__init__.py b/rllte/xplore/reward/__init__.py
index 3a213056..9e85747d 100644
--- a/rllte/xplore/reward/__init__.py
+++ b/rllte/xplore/reward/__init__.py
@@ -32,3 +32,4 @@
 from .ride import RIDE as RIDE
 from .rise import RISE as RISE
 from .rnd import RND as RND
+from .e3b import E3B as E3B
diff --git a/rllte/xplore/reward/e3b.py b/rllte/xplore/reward/e3b.py
new file mode 100644
index 00000000..64de90f6
--- /dev/null
+++ b/rllte/xplore/reward/e3b.py
@@ -0,0 +1,266 @@
+# =============================================================================
+# MIT License
+
+# Copyright (c) 2023 Reinforcement Learning Evolution Foundation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# =============================================================================
+
+
+from typing import Dict, Tuple
+
+import gymnasium as gym
+import numpy as np
+import torch as th
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader, TensorDataset
+
+from rllte.common.prototype import BaseIntrinsicRewardModule
+
+from .utils import TorchRunningMeanStd
+
+
+class Encoder(nn.Module):
+    """Encoder for encoding observations.
+
+    Args:
+        obs_shape (Tuple): The data shape of observations.
+        action_dim (int): The dimension of actions.
+        latent_dim (int): The dimension of encoding vectors.
+
+    Returns:
+        Encoder instance.
+    """
+
+    def __init__(self, obs_shape: Tuple, action_dim: int, latent_dim: int) -> None:
+        super().__init__()
+
+        # visual observations
+        if len(obs_shape) == 3:
+            self.trunk = nn.Sequential(
+                nn.Conv2d(obs_shape[0], 32, kernel_size=3, stride=2, padding=1),
+                nn.ELU(),
+                nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
+                nn.ELU(),
+                nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
+                nn.ELU(),
+                nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
+                nn.ELU(),
+                nn.Flatten(),
+            )
+            with th.no_grad():
+                sample = th.ones(size=tuple(obs_shape))
+                n_flatten = self.trunk(sample.unsqueeze(0)).shape[1]
+
+            self.linear = nn.Linear(n_flatten, latent_dim)
+        # state-based observations
+        else:
+            self.trunk = nn.Sequential(nn.Linear(obs_shape[0], 256), nn.ReLU())
+            self.linear = nn.Linear(256, latent_dim)
+
+    def forward(self, obs: th.Tensor) -> th.Tensor:
+        """Encode the input tensors.
+
+        Args:
+            obs (th.Tensor): Observations.
+
+        Returns:
+            Encoding tensors.
+        """
+        return self.linear(self.trunk(obs))
+
+
+class InverseDynamicsModel(nn.Module):
+    """Inverse dynamics model for reconstructing the transition process.
+
+    Args:
+        latent_dim (int): The dimension of encoding vectors of the observations.
+        action_dim (int): The dimension of predicted actions.
+
+    Returns:
+        Model instance.
+    """
+
+    def __init__(self, latent_dim, action_dim) -> None:
+        super().__init__()
+
+        self.trunk = nn.Sequential(nn.Linear(2 * latent_dim, 256), nn.ReLU(), nn.Linear(256, action_dim))
+
+    def forward(self, obs: th.Tensor, next_obs: th.Tensor) -> th.Tensor:
+        """Forward function for outputting predicted actions.
+
+        Args:
+            obs (th.Tensor): Current observations.
+            next_obs (th.Tensor): Next observations.
+
+        Returns:
+            Predicted actions.
+        """
+        return self.trunk(th.cat([obs, next_obs], dim=1))
+
+
+class E3B(BaseIntrinsicRewardModule):
+    """Exploration via Elliptical Episodic Bonuses (E3B).
+        See paper: https://proceedings.neurips.cc/paper_files/paper/2022/file/f4f79698d48bdc1a6dec20583724182b-Paper-Conference.pdf
+
+    Args:
+        observation_space (Space): The observation space of environment.
+        action_space (Space): The action space of environment.
+        device (str): Device (cpu, cuda, ...) on which the code should be run.
+        beta (float): The initial weighting coefficient of the intrinsic rewards.
+        kappa (float): The decay rate of the weighting coefficient.
+        latent_dim (int): The dimension of the ellipsoid feature vectors.
+        num_envs (int): The number of parallel environments.
+        ridge (float): The ridge coefficient used to initialize the inverse covariance matrices.
+        lr (float): The learning rate of the encoder and inverse dynamics model.
+        batch_size (int): The batch size for training the encoder and inverse dynamics model.
+
+    Returns:
+        Instance of E3B.
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        device: str = "cpu",
+        beta: float = 0.05,
+        kappa: float = 0.000025,
+        latent_dim: int = 512,
+        num_envs: int = 1,
+        ridge: float = 0.1,
+        lr: float = 0.001,
+        batch_size: int = 256,
+    ) -> None:
+        super().__init__(observation_space, action_space, device, beta, kappa)
+
+        self.elliptical_encoder = Encoder(
+            obs_shape=self._obs_shape,
+            action_dim=self._action_dim,
+            latent_dim=latent_dim,
+        ).to(self._device)
+
+        self.im = InverseDynamicsModel(latent_dim=latent_dim, action_dim=self._action_dim).to(self._device)
+        self.im_loss = nn.CrossEntropyLoss()
+        self.im_opt = th.optim.Adam(self.im.parameters(), lr=lr)
+        self.encoder_opt = th.optim.Adam(self.elliptical_encoder.parameters(), lr=lr)
+
+        self.idx = 0
+        self.ridge = ridge
+        self.batch_size = batch_size
+        self.latent_dim = latent_dim
+        self.num_envs = num_envs
+        self.running_mean_std = TorchRunningMeanStd(shape=(num_envs,), device=self._device)
+
+        # one inverse covariance matrix per environment, initialized to (1 / ridge) * I
+        self.cov_inverse = (th.eye(latent_dim) * (1.0 / ridge)).to(self._device)
+        self.outer_product_buffer = th.empty(latent_dim, latent_dim).to(self._device)
+
+        self.cov_inverse = self.cov_inverse.repeat(num_envs, 1, 1)
+        self.outer_product_buffer = self.outer_product_buffer.repeat(num_envs, 1, 1)
+
+    def compute_irs(self, samples: Dict, step: int = 0) -> th.Tensor:
+        """Normalize and return the intrinsic rewards, which were previously computed in the `add` method.
+
+        Args:
+            samples (Dict): The collected samples. A python dict like
+                {obs (n_steps, n_envs, *obs_shape),
+                 actions (n_steps, n_envs, *action_shape),
+                 rewards (n_steps, n_envs),
+                 next_obs (n_steps, n_envs, *obs_shape)}.
+            step (int): The global training step.
+
+        Returns:
+            The intrinsic rewards.
+        """
+        # compute the weighting coefficient of timestep t
+        beta_t = self._beta * np.power(1.0 - self._kappa, step)
+
+        # update the encoder and the inverse dynamics model
+        self.update(samples)
+
+        # update the running mean and std, then normalize the intrinsic rewards
+        self.running_mean_std.update(self.intrinsic_rewards)
+        return self.intrinsic_rewards / self.running_mean_std.std * beta_t
+
+    def update(self, samples: Dict) -> None:
+        """Update the intrinsic reward module if necessary.
+
+        Args:
+            samples: The collected samples. A python dict like
+                {obs (n_steps, n_envs, *obs_shape),
+                 actions (n_steps, n_envs, *action_shape),
+                 rewards (n_steps, n_envs),
+                 next_obs (n_steps, n_envs, *obs_shape)}.
+
+        Returns:
+            None
+        """
+        num_steps = samples["obs"].size()[0]
+        num_envs = samples["obs"].size()[1]
+        obs_tensor = samples["obs"].view((num_envs * num_steps, *self._obs_shape)).to(self._device)
+        next_obs_tensor = samples["next_obs"].view((num_envs * num_steps, *self._obs_shape)).to(self._device)
+        actions_tensor = samples["actions"].view(num_envs * num_steps).to(self._device)
+        actions_tensor = F.one_hot(actions_tensor.long(), self._action_dim).float()
+
+        dataset = TensorDataset(obs_tensor, actions_tensor, next_obs_tensor)
+        loader = DataLoader(dataset=dataset, batch_size=self.batch_size)
+
+        # only perform one gradient step per update, otherwise the inverse dynamics model will overfit
+        obs, actions, next_obs = next(iter(loader))
+
+        self.encoder_opt.zero_grad()
+        self.im_opt.zero_grad()
+
+        encoded_obs = self.elliptical_encoder(obs)
+        encoded_next_obs = self.elliptical_encoder(next_obs)
+
+        pred_actions = self.im(encoded_obs, encoded_next_obs)
+        im_loss = self.im_loss(pred_actions, actions)
+        im_loss.backward()
+
+        self.encoder_opt.step()
+        self.im_opt.step()
+
+    def add(self, samples: Dict) -> None:
+        """Update the per-environment ellipsoids and compute the intrinsic rewards.
+
+        Args:
+            samples: The collected samples. A python dict like
+                {obs (n_steps, n_envs, *obs_shape),
+                 actions (n_steps, n_envs, *action_shape),
+                 rewards (n_steps, n_envs),
+                 next_obs (n_steps, n_envs, *obs_shape),
+                 done (n_steps, n_envs)}.
+
+        Returns:
+            None
+        """
+        num_steps = samples["obs"].size()[0]
+        self.intrinsic_rewards = th.zeros(size=(num_steps, self.num_envs)).to(self._device)
+
+        with th.no_grad():
+            for j in range(num_steps):
+                h = self.elliptical_encoder(samples["obs"][j])
+                for env_idx in range(self.num_envs):
+                    # elliptical episodic bonus b = h^T C^{-1} h
+                    u = th.mv(self.cov_inverse[env_idx], h[env_idx])
+                    b = th.dot(h[env_idx], u).item()
+                    self.intrinsic_rewards[j, env_idx] = b
+
+                    # rank-one (Sherman-Morrison) update of the inverse covariance matrix
+                    th.outer(u, u, out=self.outer_product_buffer[env_idx])
+                    th.add(self.cov_inverse[env_idx], self.outer_product_buffer[env_idx], alpha=-(1.0 / (1.0 + b)), out=self.cov_inverse[env_idx])
+
+                    # reset the ellipsoid when the episode ends
+                    if samples["done"][j, env_idx]:
+                        self.cov_inverse[env_idx] = th.eye(self.latent_dim) * (1.0 / self.ridge)
\ No newline at end of file
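
Usage note (not part of the patch): the sketch below shows how the new pieces are intended to fit together, namely the Mario environment factory, the E3B intrinsic reward module, and intrinsic-only pretraining in on_policy_agent.py. It assumes the usual rllte workflow (PPO from rllte.agent, agent.set(reward=...), and a pretraining flag on the agent); the exact argument names and hyperparameters are illustrative assumptions, not taken from this diff.

# minimal usage sketch, assuming the standard rllte PPO workflow (illustrative, untested)
from rllte.agent import PPO
from rllte.env import make_mario_env
from rllte.xplore.reward import E3B

if __name__ == "__main__":
    device, num_envs = "cuda:0", 8
    env = make_mario_env(env_id="SuperMarioBros-v0", num_envs=num_envs, device=device)

    # pretraining=True makes the agent optimize the intrinsic returns/advantages only
    agent = PPO(env=env, device=device, pretraining=True)

    # E3B keeps one ellipsoid (inverse covariance matrix) per parallel environment
    irs = E3B(
        observation_space=env.observation_space,
        action_space=env.action_space,
        device=device,
        num_envs=num_envs,
    )
    agent.set(reward=irs)
    agent.train(num_train_steps=1_000_000)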