diff --git a/docs/09-environment-integrations/nethack.md b/docs/09-environment-integrations/nethack.md
new file mode 100644
index 000000000..5dd1f5526
--- /dev/null
+++ b/docs/09-environment-integrations/nethack.md
@@ -0,0 +1,115 @@
+# NetHack
+
+## Installation
+The integration works with `Python 3.10`; newer versions have problems building NLE.
+
+To install NetHack, you need NLE and its dependencies.
+
+```bash
+# nle dependencies
+apt-get install build-essential python3-dev python3-pip python3-numpy autoconf libtool pkg-config libbz2-dev
+conda install cmake flex bison lit
+
+# install nle locally and modify it to enable seeding and handle rendering with gymnasium
+git clone https://github.com/facebookresearch/nle.git nle && cd nle \
+&& git checkout v0.9.0 && git submodule init && git submodule update --recursive \
+&& sed '/#define NLE_ALLOW_SEEDING 1/i#define NLE_ALLOW_SEEDING 1' include/nleobs.h -i \
+&& sed '/self\.nethack\.set_initial_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/self\.nethack\.set_current_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/self\.nethack\.get_current_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/def seed(self, core=None, disp=None, reseed=True):/d' nle/env/tasks.py -i \
+&& sed '/raise RuntimeError("NetHackChallenge doesn.t allow seed changes")/d' nle/env/tasks.py -i \
+&& sed -i '/def render(self, mode="human"):/a\ if not self.last_observation:\n return' nle/env/base.py \
+&& python setup.py install && cd ..
+
+# install sample factory with nethack extras
+pip install -e .[nethack]
+conda install -c conda-forge pybind11
+pip install -e sf_examples/nethack/nethack_render_utils
+```
+
+## Running Experiments
+
+Run NetHack experiments with the scripts in `sf_examples.nethack`.
+The default parameters have been chosen to match [Dungeons & Data](https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022), which is based on the [nle sample factory baseline](https://github.com/Miffyli/nle-sample-factory-baseline). By moving from D&D to Sample Factory we managed to increase the APPO score from 2k to 2.8k.
+
+To train a model in the `nethack_challenge` environment:
+
+```
+python -m sf_examples.nethack.train_nethack \
+    --env=nethack_challenge \
+    --batch_size=4096 \
+    --num_workers=16 \
+    --num_envs_per_worker=32 \
+    --worker_num_splits=2 \
+    --rollout=32 \
+    --character=mon-hum-neu-mal \
+    --model=ChaoticDwarvenGPT5 \
+    --rnn_size=512 \
+    --experiment=nethack_monk
+```
+
+To visualize the training results, use the `enjoy_nethack` script:
+
+```
+python -m sf_examples.nethack.enjoy_nethack --env=nethack_challenge --character=mon-hum-neu-mal --experiment=nethack_monk
+```
+
+Additionally, you can use the alternative `fast_eval_nethack` script, which is much faster:
+
+```
+python -m sf_examples.nethack.fast_eval_nethack --env=nethack_challenge --sample_env_episodes=128 --num_workers=16 --num_envs_per_worker=2 --character=mon-hum-neu-mal --experiment=nethack_monk
+```
+
+### List of Supported Environments
+
+- nethack_staircase
+- nethack_score
+- nethack_pet
+- nethack_oracle
+- nethack_gold
+- nethack_eat
+- nethack_scout
+- nethack_challenge
+
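The environments listed above are registered with Sample Factory by `register_nethack_components()`, so they can also be created directly in Python. Below is a minimal sketch adapted from `sf_examples/nethack/scripts/sample_env.py` in this change set; it assumes the package was installed with the `nethack` extras and NLE was built as described in the Installation section.

```python
from sample_factory.algo.utils.rl_utils import make_dones
from sample_factory.envs.create_env import create_env
from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components

# register the NetHack environments and the encoder factory with Sample Factory
register_nethack_components()

# build a config for one of the supported environments using the NetHack defaults
cfg = parse_nethack_args(argv=["--env=nethack_score"], evaluation=True)

env = create_env(cfg.env, cfg=cfg, render_mode=None)
env.reset()
done = False
while not done:
    # random actions, just to exercise the gymnasium-style step API
    obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
    done = make_dones(terminated, truncated)
env.close()
```

For actual training and evaluation, use the `train_nethack`, `enjoy_nethack`, and `fast_eval_nethack` entry points shown above.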
+## Results
+
+### Reports
+1. Sample Factory was benchmarked on `nethack_challenge` against Dungeons and Data. Using the same parameters, Sample Factory achieved similar sample efficiency to D&D and better running returns (2.8k vs 2k). Training was done on `nethack_challenge` with the human-monk character for 2B env steps.
+    - https://api.wandb.ai/links/bartekcupial/w69fid1w
+
+### Models
+A Sample Factory APPO model trained on the `nethack_challenge` environment is uploaded to the HuggingFace Hub. The model has been trained for 2B steps.
+
+The model below is the best model from the experiment against Dungeons and Data above. The evaluation metrics here were obtained by running the model 1024 times.
+
+Model card: https://huggingface.co/LLParallax/sample_factory_human_monk
+Evaluation results:
+```
+{
+    "reward/reward": 3245.3828125,
+    "reward/reward_min": 20.0,
+    "reward/reward_max": 18384.0,
+    "len/len": 2370.4560546875,
+    "len/len_min": 27.0,
+    "len/len_max": 21374.0,
+    "policy_stats/avg_score": 3245.4716796875,
+    "policy_stats/avg_turns": 14693.970703125,
+    "policy_stats/avg_dlvl": 1.13671875,
+    "policy_stats/avg_max_hitpoints": 46.42578125,
+    "policy_stats/avg_max_energy": 34.00390625,
+    "policy_stats/avg_armor_class": 4.68359375,
+    "policy_stats/avg_experience_level": 6.13671875,
+    "policy_stats/avg_experience_points": 663.375,
+    "policy_stats/avg_eating_score": 14063.2587890625,
+    "policy_stats/avg_gold_score": 76.033203125,
+    "policy_stats/avg_scout_score": 499.0478515625,
+    "policy_stats/avg_sokobanfillpit_score": 0.0,
+    "policy_stats/avg_staircase_pet_score": 0.005859375,
+    "policy_stats/avg_staircase_score": 4.9970703125,
+    "policy_stats/avg_episode_number": 1.5,
+    "policy_stats/avg_true_objective": 3245.3828125,
+    "policy_stats/avg_true_objective_min": 20.0,
+    "policy_stats/avg_true_objective_max": 18384.0
+}
+```
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 46254b904..a85ae51c2 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,16 @@
 
 _atari_deps = ["gymnasium[atari, accept-rom-license]"]
 _mujoco_deps = ["gymnasium[mujoco]", "mujoco<2.5"]
+_nethack_deps = [
+    "numba ~= 0.58",
+    "pandas ~= 2.1",
+    "matplotlib ~= 3.8",
+    "seaborn ~= 0.12",
+    "scipy ~= 1.11",
+    "shimmy",
+    "tqdm ~= 4.66",
+    "debugpy ~= 1.6",
+]
 _envpool_deps = ["envpool"]
 
 _docs_deps = [
@@ -67,6 +77,7 @@
         "atari": _atari_deps,
         "envpool": _envpool_deps,
         "mujoco": _mujoco_deps,
+        "nethack": _nethack_deps,
         "vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"],
         # "dmlab": ["dm_env"],  <-- these are just auxiliary packages, the main package has to be built from sources
     },
diff --git a/sf_examples/nethack/enjoy_nethack.py b/sf_examples/nethack/enjoy_nethack.py
new file mode 100644
index 000000000..443618e50
--- /dev/null
+++ b/sf_examples/nethack/enjoy_nethack.py
@@ -0,0 +1,30 @@
+import sys
+
+from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
+from sample_factory.enjoy import enjoy
+from sf_examples.nethack.nethack_params import (
+    add_extra_params_general,
+    add_extra_params_model,
+    add_extra_params_nethack_env,
+    nethack_override_defaults,
+)
+from sf_examples.nethack.train_nethack import register_nethack_components
+
+
+def main():  # pragma: no cover
+    """Script entry point."""
+    register_nethack_components()
+
+    parser, cfg = parse_sf_args(evaluation=True)
+    add_extra_params_nethack_env(parser)
+    add_extra_params_model(parser)
+    add_extra_params_general(parser)
+    nethack_override_defaults(cfg.env, parser)
+    cfg = parse_full_cfg(parser)
+
+    status = enjoy(cfg)
+    return status
+
+
+if __name__ == "__main__":  # pragma: no cover
+    sys.exit(main())
diff --git a/sf_examples/nethack/fast_eval_nethack.py b/sf_examples/nethack/fast_eval_nethack.py
new file mode 100644
index 000000000..d05572698
--- /dev/null
+++ b/sf_examples/nethack/fast_eval_nethack.py
@@ -0,0 +1,35 @@
+import sys
+
+from
sample_factory.cfg.arguments import checkpoint_override_defaults, parse_full_cfg, parse_sf_args +from sample_factory.eval import do_eval +from sf_examples.nethack.nethack_params import ( + add_extra_params_general, + add_extra_params_model, + add_extra_params_nethack_env, + nethack_override_defaults, +) +from sf_examples.nethack.train_nethack import register_nethack_components + + +def main(): # pragma: no cover + """Script entry point.""" + register_nethack_components() + + parser, cfg = parse_sf_args(evaluation=True) + add_extra_params_nethack_env(parser) + add_extra_params_model(parser) + add_extra_params_general(parser) + nethack_override_defaults(cfg.env, parser) + + # important, instead of `load_from_checkpoint` as in enjoy we want + # to override it here to be able to use argv arguments + checkpoint_override_defaults(cfg, parser) + + cfg = parse_full_cfg(parser) + + status = do_eval(cfg) + return status + + +if __name__ == "__main__": # pragma: no cover + sys.exit(main()) diff --git a/sf_examples/nethack/models/__init__.py b/sf_examples/nethack/models/__init__.py new file mode 100644 index 000000000..70c1a1138 --- /dev/null +++ b/sf_examples/nethack/models/__init__.py @@ -0,0 +1,6 @@ +from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5 + +MODELS = [ + ChaoticDwarvenGPT5, +] +MODELS_LOOKUP = {c.__name__: c for c in MODELS} diff --git a/sf_examples/nethack/models/chaotic_dwarf.py b/sf_examples/nethack/models/chaotic_dwarf.py new file mode 100644 index 000000000..b2d0df458 --- /dev/null +++ b/sf_examples/nethack/models/chaotic_dwarf.py @@ -0,0 +1,304 @@ +"""Adapted from Chaos Dwarf in Nethack Challenge Starter Kit: +https://github.com/Miffyli/nle-sample-factory-baseline + +MIT License + +Copyright (c) 2021 Anssi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" +import torch +from nle import nethack +from torch import nn +from torch.nn import functional as F + +from sample_factory.algo.utils.torch_utils import calc_num_elements +from sample_factory.model.encoder import Encoder +from sample_factory.utils.typing import Config, ObsSpace + + +class MessageEncoder(nn.Module): + def __init__(self): + super(MessageEncoder, self).__init__() + self.hidden_dim = 128 + self.msg_fwd = nn.Sequential( + nn.Linear(nethack.MESSAGE_SHAPE[0], 128), + nn.ELU(inplace=True), + nn.Linear(128, self.hidden_dim), + nn.ELU(inplace=True), + ) + + def forward(self, message): + return self.msg_fwd(message / 255.0) + + +class BLStatsEncoder(nn.Module): + def __init__(self): + super(BLStatsEncoder, self).__init__() + self.hidden_dim = 128 + nethack.BLSTATS_SHAPE[0] + self.blstats_fwd = nn.Sequential( + nn.Linear(nethack.BLSTATS_SHAPE[0], 128), + nn.ELU(inplace=True), + nn.Linear(128, 128), + nn.ELU(inplace=True), + ) + + normalization_stats = torch.tensor( + [ + 1.0 / 79.0, # hero col + 1.0 / 21, # hero row + 0.0, # strength pct + 1.0 / 10, # strength + 1.0 / 10, # dexterity + 1.0 / 10, # constitution + 1.0 / 10, # intelligence + 1.0 / 10, # wisdom + 1.0 / 10, # charisma + 0.0, # score + 1.0 / 10, # hitpoints + 1.0 / 10, # max hitpoints + 0.0, # depth + 1.0 / 1000, # gold + 1.0 / 10, # energy + 1.0 / 10, # max energy + 1.0 / 10, # armor class + 0.0, # monster level + 1.0 / 10, # experience level + 1.0 / 100, # experience points + 1.0 / 1000, # time + 1.0, # hunger_state + 1.0 / 10, # carrying capacity + 0.0, # carrying capacity + 0.0, # level number + 0.0, # condition bits + 0.0, # alignment bits + ], + requires_grad=False, + ) + self.register_buffer("normalization_stats", normalization_stats) + + self.blstat_range = (-5, 5) + + def forward(self, blstats): + norm_bls = torch.clip( + blstats * self.normalization_stats, + self.blstat_range[0], + self.blstat_range[1], + ) + + return torch.cat([self.blstats_fwd(norm_bls), norm_bls], dim=-1) + + +class TopLineEncoder(nn.Module): + def __init__(self): + super(TopLineEncoder, self).__init__() + self.hidden_dim = 128 + self.i_dim = nethack.NLE_TERM_CO * 256 + + self.msg_fwd = nn.Sequential( + nn.Linear(self.i_dim, 128), + nn.ELU(inplace=True), + nn.Linear(128, self.hidden_dim), + nn.ELU(inplace=True), + ) + + def forward(self, message): + # Characters start at 33 in ASCII and go to 128. 
96 = 128 - 32 + message_normed = F.one_hot((message).long(), 256).reshape(-1, self.i_dim).float() + return self.msg_fwd(message_normed) + + +class BottomLinesEncoder(nn.Module): + def __init__(self): + super(BottomLinesEncoder, self).__init__() + self.conv_layers = [] + w = nethack.NLE_TERM_CO * 2 + for in_ch, out_ch, filter, stride in [[2, 32, 8, 4], [32, 64, 4, 1]]: + self.conv_layers.append(nn.Conv1d(in_ch, out_ch, filter, stride=stride)) + self.conv_layers.append(nn.ELU(inplace=True)) + w = conv_outdim(w, filter, padding=0, stride=stride) + + self.conv_net = nn.Sequential(*self.conv_layers) + self.fwd_net = nn.Sequential( + nn.Linear(w * out_ch, 128), + nn.ELU(), + nn.Linear(128, 128), + nn.ELU(), + ) + self.hidden_dim = 128 + + def forward(self, bottom_lines): + B, D = bottom_lines.shape + # ASCII 32: ' ', ASCII [33-128]: visible characters + chars_normalised = (bottom_lines - 32) / 96 + + # ASCII [45-57]: -./01234556789 + numbers_mask = (bottom_lines > 44) * (bottom_lines < 58) + digits_normalised = numbers_mask * (bottom_lines - 47) / 10 + + # Put in different channels & conv (B, 2, D) + x = torch.stack([chars_normalised, digits_normalised], dim=1) + return self.fwd_net(self.conv_net(x).view(B, -1)) + + +def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1): + """Return the dimension after applying a convolution along one axis""" + return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride) + + +class InverseModel(nn.Module): + def __init__(self, h_dim, action_space): + super(InverseModel, self).__init__() + self.h_dim = h_dim * 2 + self.action_space = action_space + + self.fwd_model = nn.Sequential( + nn.Linear(self.h_dim, 128), + nn.ELU(inplace=True), + nn.Linear(128, 128), + nn.ELU(inplace=True), + nn.Linear(128, action_space), + ) + + def forward(self, obs): + T, B, *_ = obs.shape + x = torch.cat([obs[:-1], obs[1:]], dim=-1) + pred_a = self.fwd_model(x) + off_by_one = torch.ones((1, B, self.action_space), device=x.device) * -1 + return torch.cat([pred_a, off_by_one], dim=0) + + +class ScreenEncoder(nn.Module): + def __init__(self, screen_shape): + super(ScreenEncoder, self).__init__() + conv_layers = [] + + self.h, self.w = screen_shape + self.hidden_dim = 512 + + self.conv_filters = [ + [3, 32, 8, 6, 1], + [32, 64, 4, 2, 1], + [64, 128, 3, 2, 1], + [128, 128, 3, 1, 1], + ] + + for ( + in_channels, + out_channels, + filter_size, + stride, + dilation, + ) in self.conv_filters: + conv_layers.append( + nn.Conv2d( + in_channels, + out_channels, + filter_size, + stride=stride, + dilation=dilation, + ) + ) + conv_layers.append(nn.ELU(inplace=True)) + + self.h = conv_outdim(self.h, filter_size, padding=0, stride=stride, dilation=dilation) + self.w = conv_outdim(self.w, filter_size, padding=0, stride=stride, dilation=dilation) + + self.conv_head = nn.Sequential(*conv_layers) + self.out_size = self.h * self.w * out_channels + + self.fc_head = nn.Sequential(nn.Linear(self.out_size, self.hidden_dim), nn.ELU(inplace=True)) + + def forward(self, screen_image): + x = self.conv_head(screen_image / 255.0) + x = x.view(-1, self.out_size) + x = self.fc_head(x) + return x + + +class ChaoticDwarvenGPT5(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + self.obs_keys = list(sorted(obs_space.keys())) # always the same order + self.encoders = nn.ModuleDict() + + self.use_tty_only = cfg.use_tty_only + self.use_prev_action = cfg.use_prev_action + + # screen encoder (TODO: could also use only tty_chars) + pixel_size = cfg.pixel_size + if 
cfg.crop_dim == 0: + screen_shape = (24 * pixel_size, 80 * pixel_size) + else: + screen_shape = (cfg.crop_dim * pixel_size, cfg.crop_dim * pixel_size) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + screen_shape = obs_space["screen_image"].shape + + # top and bottom encoders + if self.use_tty_only: + self.topline_encoder = TopLineEncoder() + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + topline_shape = (obs_space["tty_chars"].shape[1],) + bottomline_shape = (2 * obs_space["tty_chars"].shape[1],) + else: + self.topline_encoder = torch.jit.script(MessageEncoder()) + self.bottomline_encoder = torch.jit.script(BLStatsEncoder()) + topline_shape = obs_space["message"].shape + bottomline_shape = obs_space["blstats"].shape + + if self.use_prev_action: + self.num_actions = obs_space["prev_actions"].n + self.prev_actions_dim = self.num_actions + else: + self.num_actions = None + self.prev_actions_dim = 0 + + self.encoder_out_size = sum( + [ + calc_num_elements(self.screen_encoder, screen_shape), + calc_num_elements(self.topline_encoder, topline_shape), + calc_num_elements(self.bottomline_encoder, bottomline_shape), + self.prev_actions_dim, + ] + ) + + def forward(self, obs_dict): + B, C, H, W = obs_dict["screen_image"].shape + + if self.use_tty_only: + topline = obs_dict["tty_chars"][..., 0, :] + bottom_line = obs_dict["tty_chars"][..., -2:, :] + else: + topline = obs_dict["message"] + bottom_line = obs_dict["blstats"] + + encodings = [ + self.topline_encoder(topline.float(memory_format=torch.contiguous_format).view(B, -1)), + self.bottomline_encoder(bottom_line.float(memory_format=torch.contiguous_format).view(B, -1)), + self.screen_encoder(obs_dict["screen_image"].float(memory_format=torch.contiguous_format).view(B, C, H, W)), + ] + + if self.use_prev_action: + prev_actions = obs_dict["prev_actions"].long().view(B) + encodings.append(torch.nn.functional.one_hot(prev_actions, self.num_actions)) + + return torch.cat(encodings, dim=1) + + def get_out_size(self) -> int: + return self.encoder_out_size diff --git a/sf_examples/nethack/nethack_env.py b/sf_examples/nethack/nethack_env.py new file mode 100644 index 000000000..b870aad6d --- /dev/null +++ b/sf_examples/nethack/nethack_env.py @@ -0,0 +1,107 @@ +from typing import Optional + +from nle.env.tasks import ( + NetHackChallenge, + NetHackEat, + NetHackGold, + NetHackOracle, + NetHackScore, + NetHackScout, + NetHackStaircase, + NetHackStaircasePet, +) + +from sample_factory.algo.utils.gymnasium_utils import patch_non_gymnasium_env +from sf_examples.nethack.utils.wrappers import ( + BlstatsInfoWrapper, + PrevActionsWrapper, + RenderCharImagesWithNumpyWrapperV2, + SeedActionSpaceWrapper, + TaskRewardsInfoWrapper, +) + +NETHACK_ENVS = dict( + nethack_staircase=NetHackStaircase, + nethack_score=NetHackScore, + nethack_pet=NetHackStaircasePet, + nethack_oracle=NetHackOracle, + nethack_gold=NetHackGold, + nethack_eat=NetHackEat, + nethack_scout=NetHackScout, + nethack_challenge=NetHackChallenge, +) + + +def nethack_env_by_name(name): + if name in NETHACK_ENVS.keys(): + return NETHACK_ENVS[name] + else: + raise Exception("Unknown NetHack env") + + +def make_nethack_env(env_name, cfg, env_config, render_mode: Optional[str] = None): + assert render_mode in (None, "human", "full", "ansi", "string", "rgb_array") + + env_class = nethack_env_by_name(env_name) + + observation_keys = ( + "message", + "blstats", + "tty_chars", + "tty_colors", + "tty_cursor", + # ALSO AVAILABLE (OFF for speed) + # "specials", + # "colors", 
+ # "chars", + # "glyphs", + # "inv_glyphs", + # "inv_strs", + # "inv_letters", + # "inv_oclasses", + ) + + kwargs = dict( + character=cfg.character, + max_episode_steps=cfg.max_episode_steps, + observation_keys=observation_keys, + penalty_step=cfg.penalty_step, + penalty_time=cfg.penalty_time, + penalty_mode=cfg.fn_penalty_step, + savedir=cfg.savedir, + save_ttyrec_every=cfg.save_ttyrec_every, + ) + if env_name == "challenge": + kwargs["no_progress_timeout"] = 150 + + if env_name in ("staircase", "pet", "oracle"): + kwargs.update(reward_win=cfg.reward_win, reward_lose=cfg.reward_lose) + # else: # print warning once + # warnings.warn("Ignoring cfg.reward_win and cfg.reward_lose") + + env = env_class(**kwargs) + + if cfg.add_image_observation: + env = RenderCharImagesWithNumpyWrapperV2( + env, + crop_size=cfg.crop_dim, + rescale_font_size=(cfg.pixel_size, cfg.pixel_size), + ) + + if cfg.use_prev_action: + env = PrevActionsWrapper(env) + + if cfg.add_stats_to_info: + env = BlstatsInfoWrapper(env) + env = TaskRewardsInfoWrapper(env) + + env = patch_non_gymnasium_env(env) + + if render_mode: + env.render_mode = render_mode + + if cfg.serial_mode and cfg.num_workers == 1: + # full reproducability can only be achieved in serial mode and when there is only 1 worker + env = SeedActionSpaceWrapper(env) + + return env diff --git a/sf_examples/nethack/nethack_params.py b/sf_examples/nethack/nethack_params.py new file mode 100644 index 000000000..926a8b2a7 --- /dev/null +++ b/sf_examples/nethack/nethack_params.py @@ -0,0 +1,134 @@ +from sample_factory.utils.utils import str2bool + + +def add_extra_params_nethack_env(parser): + """ + Specify any additional command line arguments for NetHack environments. + """ + p = parser + p.add_argument( + "--character", type=str, default="mon-hum-neu-mal", help="name of character. Defaults to 'mon-hum-neu-mal'." + ) + p.add_argument( + "--max_episode_steps", + type=int, + default=100000, + help="maximum amount of steps allowed before the game is forcefully quit. In such cases, `info 'end_status']` will be equal to `StepStatus.ABORTED`", + ) + p.add_argument( + "--penalty_step", type=float, default=0.0, help="constant applied to amount of frozen steps. Defaults to 0.0." + ) + p.add_argument( + "--penalty_time", type=float, default=0.0, help="constant applied to amount of frozen steps. Defaults to 0.0." + ) + p.add_argument( + "--fn_penalty_step", + type=str, + default="constant", + help="name of the mode for calculating the time step penalty. Can be `constant`, `exp`, `square`, `linear`, or `always`. Defaults to `constant`.", + ) + p.add_argument( + "--savedir", + type=str, + default=None, + help="Path to save ttyrecs (game recordings) into, if save_ttyrec_every is nonzero. If nonempty string, interpreted as a path to a new or existing directory. If " + " (empty string) or None, NLE choses a unique directory name.Defaults to `None`.", + ) + p.add_argument( + "--save_ttyrec_every", + type=int, + default=0, + help="Integer, if 0, no ttyrecs (game recordings) will be saved. Otherwise, save a ttyrec every Nth episode.", + ) + p.add_argument( + "--add_image_observation", + type=str2bool, + default=True, + help="If True, additional wrapper will render screen image. Defaults to `True`.", + ) + p.add_argument("--crop_dim", type=int, default=18, help="Crop image around the player. Defaults to `18`.") + p.add_argument( + "--pixel_size", + type=int, + default=6, + help="Rescales each character to size of `(pixel_size, pixel_size). 
Defaults to `6`.", + ) + + +def add_extra_params_model(parser): + """ + Specify any additional command line arguments for NetHack models. + """ + p = parser + p.add_argument( + "--use_prev_action", + type=str2bool, + default=True, + help="If True, the model will use previous action. Defaults to `True`", + ) + p.add_argument( + "--use_tty_only", + type=str2bool, + default=True, + help="If True, the model will use tty_chars for the topline and bottomline. Defaults to `True`", + ) + + +def add_extra_params_general(parser): + """ + Specify any additional command line arguments for NetHack. + """ + p = parser + p.add_argument( + "--model", type=str, default="ChaoticDwarvenGPT5", help="Name of the model. Defaults to `ChaoticDwarvenGPT5`." + ) + p.add_argument( + "--add_stats_to_info", + type=str2bool, + default=True, + help="If True, adds wrapper which loggs additional statisics. Defaults to `True`.", + ) + + +def nethack_override_defaults(_env, parser): + """RL params specific to NetHack envs.""" + # set hyperparameter values to the same as in d&d + parser.set_defaults( + use_record_episode_statistics=False, + gamma=0.999, + num_workers=12, + num_envs_per_worker=2, + worker_num_splits=2, + train_for_env_steps=2_000_000_000, + nonlinearity="relu", + use_rnn=True, + rnn_type="lstm", + actor_critic_share_weights=True, + policy_initialization="orthogonal", + policy_init_gain=1.0, + adaptive_stddev=False, # True only for continous action distributions + reward_scale=1.0, + reward_clip=10.0, + batch_size=1024, + rollout=32, + max_grad_norm=4, + num_epochs=1, + num_batches_per_epoch=1, # can be used for increasing the batch_size for SGD + ppo_clip_ratio=0.1, + ppo_clip_value=1.0, + value_loss_coeff=1.0, + exploration_loss="entropy", + exploration_loss_coeff=0.001, + learning_rate=0.0001, + gae_lambda=1.0, + with_vtrace=False, # in d&d they've used vtrace + normalize_input=False, # turn off for now and use normalization from d&d + normalize_returns=True, + async_rl=True, + experiment_summaries_interval=50, + adam_beta1=0.9, + adam_beta2=0.999, + adam_eps=1e-7, + seed=22, + save_every_sec=120, + ) diff --git a/sf_examples/nethack/nethack_render_utils/CMakeLists.txt b/sf_examples/nethack/nethack_render_utils/CMakeLists.txt new file mode 100644 index 000000000..727be7ad8 --- /dev/null +++ b/sf_examples/nethack/nethack_render_utils/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.4...3.18) +project(nethack_render_utils VERSION 0.0.1) + +find_package(pybind11 REQUIRED) +include_directories(${pybind11_INCLUDE_DIR}) + +pybind11_add_module(nethack_render_utils src/main.cpp) + +# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a +# define (VERSION_INFO) here. 
+target_compile_definitions(nethack_render_utils + PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) \ No newline at end of file diff --git a/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf b/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf new file mode 100644 index 000000000..097db1814 Binary files /dev/null and b/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf differ diff --git a/sf_examples/nethack/nethack_render_utils/setup.py b/sf_examples/nethack/nethack_render_utils/setup.py new file mode 100644 index 000000000..aaae05690 --- /dev/null +++ b/sf_examples/nethack/nethack_render_utils/setup.py @@ -0,0 +1,127 @@ +import os +import re +import subprocess +import sys + +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + +# Convert distutils Windows platform specifiers to CMake -A arguments +PLAT_TO_CMAKE = { + "win32": "Win32", + "win-amd64": "x64", + "win-arm32": "ARM", + "win-arm64": "ARM64", +} + + +# A CMakeExtension needs a sourcedir instead of a file list. +# The name must be the _single_ output extension from the CMake build. +# If you need multiple extensions, see scikit-build. +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection & inclusion of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + # In this example, we pass in the version to C++. You might not need to. + cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] + + if self.compiler.compiler_type != "msvc": + # Using Ninja-build since it a) is available as a wheel and b) + # multithreads automatically. MSVC would require all variables be + # exported for Ninja to pick it up, which is a little tricky to do. + # Users can override the generator with CMAKE_GENERATOR in CMake + # 3.15+. + if not cmake_generator: + try: + import ninja # noqa: F401 + + cmake_args += ["-GNinja"] + except ImportError: + pass + + else: + # Single config generators are handled "normally" + single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) + + # CMake allows an arch-in-generator style for backward compatibility + contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) + + # Specify the arch if using MSVC generator, but only if it doesn't + # contain a backward-compatibility arch spec already in the + # generator name. 
+ if not single_config and not contains_arch: + cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]] + + # Multi-config generators have a different way to specify configs + if not single_config: + cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] + build_args += ["--config", cfg] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) + + +# The information here can also be placed in setup.cfg - better separation of +# logic and declaration, and simpler if you include description/version in a file. +setup( + name="nethack_render_utils", + version="0.0.1", + author="Eric Hambro", + author_email="ehambro@fb.com", + description="Render NetHack glyphs as a screen", + long_description="", + ext_modules=[CMakeExtension("nethack_render_utils")], + cmdclass={"build_ext": CMakeBuild}, + zip_safe=False, + extras_require={"test": ["pytest>=6.0"]}, + python_requires=">=3.6", +) diff --git a/sf_examples/nethack/nethack_render_utils/src/main.cpp b/sf_examples/nethack/nethack_render_utils/src/main.cpp new file mode 100644 index 000000000..36ff5019d --- /dev/null +++ b/sf_examples/nethack/nethack_render_utils/src/main.cpp @@ -0,0 +1,213 @@ +#include + +#include +#include +#include + +#define STRINGIFY(x) #x +#define MACRO_STRINGIFY(x) STRINGIFY(x) + +namespace py = pybind11; + +template void check_array_types(py::handle h, int min_dim) { + if (h.is_none()) + throw std::invalid_argument("Array is None!"); + ; + if (!py::isinstance(h)) + throw std::invalid_argument("Numpy array required."); + + py::array array = py::array::ensure(h); + if (!array.dtype().is(py::dtype::of())) + throw std::invalid_argument("Buffer dtype mismatch."); + + if (!(array.flags() & py::array::c_style)) + throw std::invalid_argument("Array isn't C contiguous."); + + py::buffer_info buf = array.request(); + if (buf.ndim < min_dim) + throw std::invalid_argument("Wrong ndim in array"); +} + +void check_array_shapes(py::handle tty_chars, py::handle tty_colors, + py::handle tty_cursor, py::handle glyph_images, + py::handle out, int crop_size) { + + const auto &chars_shape = py::array::ensure(tty_chars).request().shape; + const auto &colors_shape = py::array::ensure(tty_colors).request().shape; + if (!std::equal(chars_shape.begin(), chars_shape.end(), colors_shape.begin())) + throw std::invalid_argument("Shape mismatch (tty_chars, tty_colors)."); + + const auto &img_shape = py::array::ensure(glyph_images).request().shape; + if (!(img_shape[0] == 256 && img_shape[1] == 16)) + throw std::invalid_argument("Shape of glyph_images must start (256, 16)"); + + int crop_height = + (crop_size > 0) ? crop_size : chars_shape[chars_shape.size() - 2]; + int crop_width = + (crop_size > 0) ? 
crop_size : chars_shape[chars_shape.size() - 1]; + const auto &out_shape = py::array::ensure(out).request().shape; + size_t dim = out_shape.size(); + if (!(out_shape[dim - 3] == img_shape[2] && + out_shape[dim - 2] == (crop_height * img_shape[3]) && + out_shape[dim - 1] == (crop_width * img_shape[4]))) + throw std::invalid_argument("Shape mismatch (glyph_images, out)."); + + const auto &cursor_shape = py::array::ensure(tty_cursor).request().shape; + if (!(cursor_shape[cursor_shape.size() - 1] == 2)) + throw std::invalid_argument("Shape of glyph_images must be (2)"); + + if (!(chars_shape.size() - 2 == cursor_shape.size() - 1 && + chars_shape.size() - 2 == out_shape.size() - 3)) { + throw std::invalid_argument("Different dims for batch conversion"); + } + + for (int i = 0; i < chars_shape.size() - 2; ++i) { + if (!(chars_shape[i] == cursor_shape[i] && + chars_shape[i] == out_shape[i])) { + throw std::invalid_argument("Different batch sizes for batch conversion"); + } + } +} + +void tile_crop(py::array_t tty_chars, py::array_t tty_colors, + py::array_t tty_cursor, py::array_t images, + py::array_t out_array, int crop_size) { + + py::buffer_info chars_buff = tty_chars.request(); + py::buffer_info colors_buff = tty_colors.request(); + py::buffer_info cursor_buff = tty_cursor.request(); + py::buffer_info images_buff = images.request(); + py::buffer_info out_buff = out_array.request(); + + const auto &chars_shape = chars_buff.shape; + const auto &img_shape = images_buff.shape; + const auto &out_shape = out_buff.shape; + + int lead_dims = chars_shape.size() - 2; + int lead_elems = 1; + for (int i = 0; i < lead_dims; ++i) { + lead_elems *= chars_shape[i]; + } + + int rows = chars_shape[lead_dims + 0]; + int cols = chars_shape[lead_dims + 1]; + + int img_colors = img_shape[1]; + int img_channels = img_shape[2]; + int img_rows = img_shape[3]; + int img_cols = img_shape[4]; + + int out_chan = out_shape[lead_dims + 0]; + int out_rows = out_shape[lead_dims + 1]; + int out_cols = out_shape[lead_dims + 2]; + + uint8_t *char_ptr = static_cast(chars_buff.ptr); + int8_t *color_ptr = static_cast(colors_buff.ptr); + uint8_t *out_ptr = static_cast(out_buff.ptr); + uint8_t *cur_ptr = static_cast(cursor_buff.ptr); + uint8_t *img_ptr = static_cast(images_buff.ptr); + + int half_crop_size = crop_size / 2; + + // Strides + int s_char_frame = rows * cols; + int s_char_row = cols; + + int s_color_frame = rows * cols; + int s_color_row = cols; + + int s_cursor_frame = 2; + + int s_img_col = img_cols; + int s_img_row = img_rows * img_cols; + int s_img_color = img_channels * img_rows * img_cols; + int s_img_glyph = img_colors * img_channels * img_rows * img_cols; + + int s_out_frame = out_chan * out_rows * out_cols; + int s_out_chan = out_rows * out_cols; + int s_out_row = out_cols; + { + py::gil_scoped_release release; + + for (size_t i = 0; i < lead_elems; ++i) { + auto chars_at = [char_ptr, s_char_row](int h, int w) { + return *(char_ptr + h * s_char_row + w); + }; + auto colors_at = [color_ptr, s_color_row](int h, int w) { + return *(color_ptr + h * s_color_row + w); + }; + auto img_at = [img_ptr, s_img_glyph, s_img_color, s_img_row, + s_img_col](int glyph, int color, int chan, int h, int w) { + return *(img_ptr + glyph * s_img_glyph + color * s_img_color + + chan * s_img_row + h * s_img_col + w); + }; + auto out_ptr_ = [out_ptr, s_out_chan, s_out_row](int chan, int h, int w) { + return (out_ptr + chan * s_out_chan + h * s_out_row + w); + }; + + int start_h = (crop_size > 0) ? 
cur_ptr[0] - half_crop_size : 0; + int start_w = (crop_size > 0) ? cur_ptr[1] - half_crop_size : 0; + + int max_r = (crop_size > 0) ? crop_size : rows; + int max_c = (crop_size > 0) ? crop_size : cols; + for (size_t r = 0; r < max_r; ++r) { + int h = r + start_h; + for (size_t c = 0; c < max_c; ++c) { + int w = c + start_w; + for (size_t i_chan = 0; i_chan < img_channels; ++i_chan) { + for (size_t i_r = 0; i_r < img_rows; ++i_r) { + for (size_t i_c = 0; i_c < img_cols; ++i_c) { + + if ((h < 0 || h >= rows || w < 0 || w >= cols)) { + *out_ptr_(i_chan, r * img_rows + i_r, c * img_cols + i_c) = 0; + } else { + int this_glyph = chars_at(h, w); + int this_color = colors_at(h, w); + *out_ptr_(i_chan, r * img_rows + i_r, c * img_cols + i_c) = + img_at(this_glyph, this_color, i_chan, i_r, i_c); + } + } + } + } + } + } + char_ptr += s_char_frame; + color_ptr += s_color_frame; + cur_ptr += s_cursor_frame; + out_ptr += s_out_frame; + } + } +} +void render_crop(py::object tty_chars, py::object tty_colors, + py::object tty_cursor, py::object images, py::object out_array, + int crop_size) { + + check_array_types(tty_chars, 2); + check_array_types(tty_colors, 2); + check_array_types(tty_cursor, 1); + check_array_types(images, 5); + check_array_types(out_array, 3); + check_array_shapes(tty_chars, tty_colors, tty_cursor, images, out_array, + crop_size); + + tile_crop(tty_chars, tty_colors, tty_cursor, images, out_array, crop_size); +} + +namespace py = pybind11; + +PYBIND11_MODULE(nethack_render_utils, m) { + m.doc() = R"pbdoc( + A module to turn glyphs into the screen in pixels + ----------------------- + )pbdoc"; + + m.def("render_crop", &render_crop, py::arg("tty_chars"), + py::arg("tty_colors"), py::arg("tty_cursor"), py::arg("images"), + py::arg("out_array"), py::arg("crop_size") = 12); + +#ifdef VERSION_INFO + m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); +#else + m.attr("__version__") = "dev"; +#endif +} diff --git a/sf_examples/nethack/nethack_render_utils/test.py b/sf_examples/nethack/nethack_render_utils/test.py new file mode 100644 index 000000000..e15b429af --- /dev/null +++ b/sf_examples/nethack/nethack_render_utils/test.py @@ -0,0 +1,137 @@ +import sys +import time +from concurrent.futures import ThreadPoolExecutor + +import gym +import nethack_render_utils as m +import nle # NOQA: F401 +import numpy as np +import tqdm +from PIL import Image as im + +sys.path.append("/private/home/ehambro/fair/workspaces/wrapper-hackrl/hackrl") +import wrappers # NOQA: E402 + + +def create_env(): + return wrappers.RenderCharImagesWithNumpyWrapper(gym.make("NetHackChallenge-v0"), blstats_cursor=False) + + +def load_obs(): + e = create_env() + e.reset() + e.step(0) + obs = e.step(1)[0] + obs = e.step(5)[0] + + images = e.char_array.copy() + + return ( + obs["tty_chars"].copy(), + obs["tty_colors"].copy(), + obs["tty_cursor"].copy(), + images, + obs["screen_image"].copy(), + ) + + +def test_main(): + assert m.__version__ == "0.0.1" + assert m.add(1, 2) == 3 + assert m.subtract(1, 2) == -1 + + +np.set_printoptions(threshold=sys.maxsize) + + +def test(_): + obs = [np.ascontiguousarray(x) for x in load_obs()] + chars, colors, cursor, images, screen_image = obs + out = np.zeros_like(screen_image, order="C") + out = np.zeros((3, 72, 72), order="C", dtype=np.uint8) + + m.render_crop(chars, colors, cursor, images, out, screen_image) + + if not np.all(out == screen_image): + scr_im = im.fromarray(np.transpose(screen_image, (1, 2, 0))) + out_im = im.fromarray(np.transpose(out, (1, 2, 0))) + + # saving the final 
output + # as a PNG file + out_im.save("out_im.png") + scr_im.save("scr_im.png") + print(cursor[1] - 6, cursor[1] + 6) + print( + chars[ + max(cursor[0] - 6, 0) : cursor[0] + 6, + max(cursor[1] - 6, 0) : cursor[1] + 6, + ] + ) + + np.testing.assert_array_equal(out, screen_image) + + +if __name__ == "__main__": + with ThreadPoolExecutor(max_workers=10) as tp: + + def fn(_): + obs = [np.ascontiguousarray(x) for x in load_obs()] + chars, colors, cursor, images, screen_image = obs + + out = np.zeros_like(screen_image, order="C") + m.render_crop(chars, colors, cursor, images, out) + np.testing.assert_array_equal(screen_image, out) + + def fn_batched(_): + obs = [np.ascontiguousarray(x) for x in load_obs()] + chars, colors, cursor, images, screen_image = obs + obs = [np.ascontiguousarray(np.stack([x] * 10)) for x in (chars, colors, cursor, screen_image)] + obs = [np.ascontiguousarray(np.stack([x] * 20)) for x in obs] + (chars, colors, cursor, screen_image) = obs + + out = np.zeros_like(screen_image, order="C") + m.render_crop(chars, colors, cursor, images, out) + np.testing.assert_array_equal(screen_image, out) + + retries = 100 + batch_size = (100, 100) + obs = [] + for _ in range(retries): + this_obs = [np.ascontiguousarray(x) for x in load_obs()] + chars, colors, cursor, images, screen_image = this_obs + z = [np.ascontiguousarray(np.stack([x] * batch_size[0])) for x in (chars, colors, cursor, screen_image)] + z = [np.ascontiguousarray(np.stack([x] * batch_size[1])) for x in z] + (chars, colors, cursor, screen_image) = z + this_obs = chars, colors, cursor, images, screen_image + out = np.zeros_like(screen_image, order="C") + + obs.append((chars, colors, cursor, images, out)) + + print("Testing") + list(map(fn, tqdm.tqdm(range(200)))) + + print("Testing Batched") + list(map(fn_batched, tqdm.tqdm(range(200)))) + + print("Profile Single Thread") + start = time.time() + for o in obs: + chars, colors, cursor, images, out = o + m.render_crop(chars, colors, cursor, images, out) + t = time.time() - start + print("Time:", t) + print("SPS:", retries * batch_size[0] * batch_size[1] / t) + + print("Profile Batch") + start = time.time() + + def _parallel(o, i): + chars, colors, cursor, images, out = o + m.render_crop(chars[i], colors[i], cursor[i], images, out[i]) + + for o in obs: + list(tp.map(_parallel, [(o, i) for i in range(batch_size[0])])) + + t = time.time() - start + print("Time:", t) + print("SPS:", retries * batch_size[0] * batch_size[1] / t) diff --git a/sf_examples/nethack/scripts/sample_env.py b/sf_examples/nethack/scripts/sample_env.py new file mode 100644 index 000000000..5b1b260f0 --- /dev/null +++ b/sf_examples/nethack/scripts/sample_env.py @@ -0,0 +1,34 @@ +import sys + +from sample_factory.algo.utils.rl_utils import make_dones +from sample_factory.envs.create_env import create_env +from sample_factory.utils.utils import log +from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components + + +def main(): + register_nethack_components() + cfg = parse_nethack_args(evaluation=True) + + render_mode = "human" + if cfg.save_video: + render_mode = "rgb_array" + elif cfg.no_render: + render_mode = None + + env = create_env(cfg.env, cfg=cfg, render_mode=render_mode) + + env.seed(0) + env.action_space.seed(0) + + for i in range(10): + env.reset() + done = False + while not done: + obs, rew, terminated, truncated, info = env.step(env.action_space.sample()) + done = make_dones(terminated, truncated) + log.info("Done!") + + +if __name__ == "__main__": + sys.exit(main()) 
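The training entry point that follows resolves the `--model` argument through `MODELS_LOOKUP` (populated in `sf_examples/nethack/models/__init__.py`) inside `make_nethack_encoder`. As a rough sketch of how an additional encoder could be made selectable this way; `MyNetHackEncoder` is a hypothetical name used only for illustration and is not part of this change:

```python
from sample_factory.model.encoder import Encoder
from sample_factory.utils.typing import Config, ObsSpace
from sf_examples.nethack.models import MODELS, MODELS_LOOKUP


class MyNetHackEncoder(Encoder):  # hypothetical encoder, for illustration only
    def __init__(self, cfg: Config, obs_space: ObsSpace):
        super().__init__(cfg)
        # a real encoder would build its network here and implement forward(),
        # the way ChaoticDwarvenGPT5 does above
        self.encoder_out_size = 512

    def get_out_size(self) -> int:
        return self.encoder_out_size


# register it so make_nethack_encoder can resolve --model=MyNetHackEncoder
MODELS.append(MyNetHackEncoder)
MODELS_LOOKUP[MyNetHackEncoder.__name__] = MyNetHackEncoder
```

The default remains `--model=ChaoticDwarvenGPT5`, as set in `add_extra_params_general`.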
diff --git a/sf_examples/nethack/train_nethack.py b/sf_examples/nethack/train_nethack.py new file mode 100644 index 000000000..133869bce --- /dev/null +++ b/sf_examples/nethack/train_nethack.py @@ -0,0 +1,58 @@ +import sys + +from sample_factory.algo.utils.context import global_model_factory +from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args +from sample_factory.envs.env_utils import register_env +from sample_factory.model.encoder import Encoder +from sample_factory.train import run_rl +from sample_factory.utils.typing import Config, ObsSpace +from sf_examples.nethack.models import MODELS_LOOKUP +from sf_examples.nethack.nethack_env import NETHACK_ENVS, make_nethack_env +from sf_examples.nethack.nethack_params import ( + add_extra_params_general, + add_extra_params_model, + add_extra_params_nethack_env, + nethack_override_defaults, +) + + +def register_nethack_envs(): + for env_name in NETHACK_ENVS.keys(): + register_env(env_name, make_nethack_env) + + +def make_nethack_encoder(cfg: Config, obs_space: ObsSpace) -> Encoder: + """Factory function as required by the API.""" + try: + model_cls = MODELS_LOOKUP[cfg.model] + except KeyError: + raise NotImplementedError("model=%s" % cfg.model) from None + + return model_cls(cfg, obs_space) + + +def register_nethack_components(): + register_nethack_envs() + global_model_factory().register_encoder_factory(make_nethack_encoder) + + +def parse_nethack_args(argv=None, evaluation=False): + parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation) + add_extra_params_nethack_env(parser) + add_extra_params_model(parser) + add_extra_params_general(parser) + nethack_override_defaults(partial_cfg.env, parser) + final_cfg = parse_full_cfg(parser, argv) + return final_cfg + + +def main(): # pragma: no cover + """Script entry point.""" + register_nethack_components() + cfg = parse_nethack_args() + status = run_rl(cfg) + return status + + +if __name__ == "__main__": # pragma: no cover + sys.exit(main()) diff --git a/sf_examples/nethack/utils/task_rewards.py b/sf_examples/nethack/utils/task_rewards.py new file mode 100644 index 000000000..99daa8fc4 --- /dev/null +++ b/sf_examples/nethack/utils/task_rewards.py @@ -0,0 +1,168 @@ +import re + +import numpy as np +from nle import nethack + + +class Score: + def __init__(self): + self.score = 0 + # convert name to snake_case + # https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case + self.name = re.sub("(?!^)([A-Z]+)", r"_\1", self.__class__.__name__).lower() + + def reset_score(self): + self.score = 0 + + +class GoldScore(Score): + def reward(self, env, last_observation, observation, end_status): + old_blstats = last_observation[env._blstats_index] + blstats = observation[env._blstats_index] + + old_gold = old_blstats[nethack.NLE_BL_GOLD] + gold = blstats[nethack.NLE_BL_GOLD] + + reward = np.abs(gold - old_gold) + self.score += reward + + return reward + + +class EatingScore(Score): + def reward(self, env, last_observation, observation, end_status): + old_internal = last_observation[env._internal_index] + internal = observation[env._internal_index] + + reward = max(0, internal[7] - old_internal[7]) + self.score += reward + + return reward + + +class ScoutScore(Score): + def __init__(self): + super().__init__() + self.dungeon_explored = {} + + def reward(self, env, last_observation, observation, end_status): + glyphs = observation[env._glyph_index] + blstats = observation[env._blstats_index] + + dungeon_num = blstats[nethack.NLE_BL_DNUM] 
+ dungeon_level = blstats[nethack.NLE_BL_DLEVEL] + + key = (dungeon_num, dungeon_level) + explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF) + explored_old = 0 + if key in self.dungeon_explored: + explored_old = self.dungeon_explored[key] + reward = explored - explored_old + self.dungeon_explored[key] = explored + self.score += reward + + return reward + + def reset_score(self): + super().reset_score() + self.dungeon_explored = {} + + +class StaircaseScore(Score): + """ + This task requires the agent to get on top of a staircase down (>). + The reward function is :math:`I`, where :math:`I` is 1 if the + task is successful, and 0 otherwise. + """ + + def reward(self, env, last_observation, observation, end_status): + internal = observation[env._internal_index] + stairs_down = internal[4] + + reward = 1 if stairs_down else 0 + self.score += reward + + return reward + + +class StaircasePetScore(Score): + """ + This task requires the agent to get on top of a staircase down (>), while + having their pet next to it. See `NetHackStaircase` for the reward function. + """ + + def reward(self, env, last_observation, observation, end_status): + internal = observation[env._internal_index] + stairs_down = internal[4] + + reward = 0 + if stairs_down: + glyphs = observation[env._glyph_index] + blstats = observation[env._blstats_index] + x, y = blstats[:2] + + neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2] + if np.any(nethack.glyph_is_pet(neighbors)): + reward = 1 + + self.score += reward + + return reward + + +class SokobanfillpitScore(Score): + """ + This task requires the agent to put the boulders inside wholes for sokoban. + We count each successful boulder moved into a whole as a total reward. + """ + + def reward(self, env, last_observation, observation, end_status): + # the score counts how many pits we fill + char_array = [chr(i) for i in observation[env._message_index]] + message = "".join(char_array) + + if message.startswith("The boulder fills a pit.") or message.startswith( + "The boulder falls into and plugs a whole in the floor!" + ): + reward = 1 + else: + reward = 0 + self.score += reward + + return reward + + +class SokobansolvedlevelsScore(Score): + def __init__(self): + super().__init__() + self.sokoban_levels = {} + + def reward(self, env, last_observation, observation, end_status): + glyphs = observation[env._glyph_index] + blstats = observation[env._blstats_index] + + dungeon_num = blstats[nethack.NLE_BL_DNUM] + dungeon_level = blstats[nethack.NLE_BL_DLEVEL] + + # when we know that this is sokoban + if dungeon_num == 4: + # TODO: maybe we should count "solving" sokoban level when we reach the next level of the sokoban? 
+ # checking if all pits are solved can be buggy if the glyphs have different values on other levels + + # count the number of pits, glyphs SS.S_pit + pits = np.isin(glyphs, [2411]).sum() + key = (dungeon_num, dungeon_level) + self.sokoban_levels[key] = pits + + def reset_score(self): + super().reset_score() + self.sokoban_levels = {} + + @property + def score(self): + score = 0 + for pits in self.sokoban_levels.values(): + # when all pits are filled we assume that sokoban level is solved + if pits == 0: + score += 1 + return score diff --git a/sf_examples/nethack/utils/wrappers/__init__.py b/sf_examples/nethack/utils/wrappers/__init__.py new file mode 100644 index 000000000..6867f54dd --- /dev/null +++ b/sf_examples/nethack/utils/wrappers/__init__.py @@ -0,0 +1,13 @@ +from sf_examples.nethack.utils.wrappers.blstats_info import BlstatsInfoWrapper +from sf_examples.nethack.utils.wrappers.prev_actions import PrevActionsWrapper +from sf_examples.nethack.utils.wrappers.screen_image import RenderCharImagesWithNumpyWrapperV2 +from sf_examples.nethack.utils.wrappers.seed_action_space import SeedActionSpaceWrapper +from sf_examples.nethack.utils.wrappers.task_rewards import TaskRewardsInfoWrapper + +__all__ = [ + RenderCharImagesWithNumpyWrapperV2, + PrevActionsWrapper, + TaskRewardsInfoWrapper, + BlstatsInfoWrapper, + SeedActionSpaceWrapper, +] diff --git a/sf_examples/nethack/utils/wrappers/blstats_info.py b/sf_examples/nethack/utils/wrappers/blstats_info.py new file mode 100644 index 000000000..faef9fdd6 --- /dev/null +++ b/sf_examples/nethack/utils/wrappers/blstats_info.py @@ -0,0 +1,36 @@ +from collections import namedtuple + +import gym + +BLStats = namedtuple( + "BLStats", + "x y strength_percentage strength dexterity constitution intelligence wisdom charisma score hitpoints max_hitpoints depth gold energy max_energy armor_class monster_level experience_level experience_points time hunger_state carrying_capacity dungeon_number level_number prop_mask align_bits", +) + + +class BlstatsInfoWrapper(gym.Wrapper): + def step(self, action): + # because we will see done=True at the first timestep of the new episode + # to properly calculate blstats at the end of the episode we need to keep the last_observation around + last_observation = tuple(a.copy() for a in self.env.unwrapped.last_observation) + obs, reward, done, info = self.env.step(action) + + if done: + info["episode_extra_stats"] = self.add_more_stats(info, last_observation) + + return obs, reward, done, info + + def add_more_stats(self, info, last_observation): + extra_stats = info.get("episode_extra_stats", {}) + blstats = BLStats(*last_observation[self.env.unwrapped._blstats_index]) + new_extra_stats = { + "score": blstats.score, + "turns": blstats.time, + "dlvl": blstats.depth, + "max_hitpoints": blstats.max_hitpoints, + "max_energy": blstats.max_energy, + "armor_class": blstats.armor_class, + "experience_level": blstats.experience_level, + "experience_points": blstats.experience_points, + } + return {**extra_stats, **new_extra_stats} diff --git a/sf_examples/nethack/utils/wrappers/prev_actions.py b/sf_examples/nethack/utils/wrappers/prev_actions.py new file mode 100644 index 000000000..0afb67793 --- /dev/null +++ b/sf_examples/nethack/utils/wrappers/prev_actions.py @@ -0,0 +1,24 @@ +import gym +import numpy as np + + +class PrevActionsWrapper(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.prev_action = 0 + + obs_spaces = {"prev_actions": self.env.action_space} + obs_spaces.update([(k, 
self.env.observation_space[k]) for k in self.env.observation_space]) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def reset(self, **kwargs): + self.prev_action = 0 + obs = self.env.reset(**kwargs) + obs["prev_actions"] = np.array([self.prev_action]) + return obs + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.prev_action = action + obs["prev_actions"] = np.array([self.prev_action]) + return obs, reward, done, info diff --git a/sf_examples/nethack/utils/wrappers/screen_image.py b/sf_examples/nethack/utils/wrappers/screen_image.py new file mode 100644 index 000000000..fae99f25d --- /dev/null +++ b/sf_examples/nethack/utils/wrappers/screen_image.py @@ -0,0 +1,310 @@ +"""Taken & adapted from Chaos Dwarf in Nethack Challenge Starter Kit: +https://github.com/Miffyli/nle-sample-factory-baseline +MIT License +Copyright (c) 2021 Anssi +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import os +from typing import Any, SupportsFloat + +import cv2 +import gym +import nethack_render_utils +import numpy as np +from nle import nethack +from numba import njit +from PIL import Image, ImageDraw, ImageFont + +SMALL_FONT_PATH = os.path.abspath("sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf") + + +# Mapping of 0-15 colors used. +# Taken from bottom image here. 
It seems about right +# https://i.stack.imgur.com/UQVe5.png +COLORS = [ + "#000000", + "#800000", + "#008000", + "#808000", + "#000080", + "#800080", + "#008080", + "#808080", # - flipped these ones around + "#C0C0C0", # | the gray-out dull stuff + "#FF0000", + "#00FF00", + "#FFFF00", + "#0000FF", + "#FF00FF", + "#00FFFF", + "#FFFFFF", +] + + +@njit +def _tile_characters_to_image( + out_image, + chars, + colors, + output_height_chars, + output_width_chars, + char_array, + offset_h, + offset_w, +): + """ + Build an image using cached images of characters in char_array to out_image + """ + char_height = char_array.shape[3] + char_width = char_array.shape[4] + for h in range(output_height_chars): + h_char = h + offset_h + # Stuff outside boundaries is not visible, so + # just leave it black + if h_char < 0 or h_char >= chars.shape[0]: + continue + for w in range(output_width_chars): + w_char = w + offset_w + if w_char < 0 or w_char >= chars.shape[1]: + continue + char = chars[h_char, w_char] + color = colors[h_char, w_char] + h_pixel = h * char_height + w_pixel = w * char_width + out_image[:, h_pixel : h_pixel + char_height, w_pixel : w_pixel + char_width] = char_array[char, color] + + +def _initialize_char_array(font_size, rescale_font_size): + """Draw all characters in PIL and cache them in numpy arrays + if rescale_font_size is given, assume it is (width, height) + Returns a np array of (num_chars, num_colors, char_height, char_width, 3) + """ + font = ImageFont.truetype(SMALL_FONT_PATH, font_size) + dummy_text = "".join([(chr(i) if chr(i).isprintable() else " ") for i in range(256)]) + bboxes = np.array([font.getbbox(char) for char in dummy_text]) + image_width = bboxes[:, 2].max() + image_height = bboxes[:, 3].max() + + char_width = rescale_font_size[0] + char_height = rescale_font_size[1] + char_array = np.zeros((256, 16, char_height, char_width, 3), dtype=np.uint8) + + for color_index in range(16): + for char_index in range(256): + char = dummy_text[char_index] + + image = Image.new("RGB", (image_width, image_height)) + image_draw = ImageDraw.Draw(image) + image_draw.rectangle((0, 0, image_width, image_height), fill=(0, 0, 0)) + + _, _, width, height = font.getbbox(char) + x = (image_width - width) // 2 + y = (image_height - height) // 2 + image_draw.text((x, y), char, font=font, fill=COLORS[color_index]) + + arr = np.array(image).copy() + if rescale_font_size: + arr = cv2.resize(arr, rescale_font_size, interpolation=cv2.INTER_AREA) + char_array[char_index, color_index] = arr + + return char_array + + +class RenderCharImagesWithNumpyWrapper(gym.Wrapper): + """ + Render characters as images, using PIL to render characters like we humans see on screen + but then some caching and numpy stuff to speed up things. + To speed things up, crop image around the player. 
+ """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + blstats_cursor=False, + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + + self.crop_size = crop_size + self.blstats_cursor = blstats_cursor + + self.half_crop_size = crop_size // 2 + self.output_height_chars = crop_size + self.output_width_chars = crop_size + self.chw_image_shape = ( + 3, + self.output_height_chars * self.char_height, + self.output_width_chars * self.char_width, + ) + + obs_spaces = {"screen_image": gym.spaces.Box(low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8)} + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def _render_text_to_image(self, obs): + chars = obs["tty_chars"] + colors = obs["tty_colors"] + offset_w = 0 + offset_h = 0 + if self.crop_size: + # Center around player + if self.blstats_cursor: + center_x, center_y = obs["blstats"][:2] + else: + center_y, center_x = obs["tty_cursor"] + offset_h = center_y - self.half_crop_size + offset_w = center_x - self.half_crop_size + + out_image = np.zeros(self.chw_image_shape, dtype=np.uint8) + + _tile_characters_to_image( + out_image=out_image, + chars=chars, + colors=colors, + output_height_chars=self.output_height_chars, + output_width_chars=self.output_width_chars, + char_array=self.char_array, + offset_h=offset_h, + offset_w=offset_w, + ) + + obs["screen_image"] = out_image + return obs + + def step(self, action): + obs, reward, done, info = self.env.step(action) + obs = self._render_text_to_image(obs) + return obs, reward, done, info + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + obs = self._render_text_to_image(obs) + return obs + + +class RenderCharImagesWithNumpyWrapperV2(gym.Wrapper): + """ + Same as V1, but simpler and faster. 
+ """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + render_font_size=(6, 11), + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + self.char_array = np.ascontiguousarray(self.char_array) + self.crop_size = crop_size + + crop_rows = crop_size or nethack.nethack.TERMINAL_SHAPE[0] + crop_cols = crop_size or nethack.nethack.TERMINAL_SHAPE[1] + + self.chw_image_shape = ( + 3, + crop_rows * self.char_height, + crop_cols * self.char_width, + ) + + obs_spaces = {"screen_image": gym.spaces.Box(low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8)} + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + # if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + self.render_char_array = _initialize_char_array(font_size, render_font_size) + self.render_char_array = self.render_char_array.transpose(0, 1, 4, 2, 3) + self.render_char_array = np.ascontiguousarray(self.render_char_array) + + def _populate_obs(self, obs): + screen = np.zeros(self.chw_image_shape, order="C", dtype=np.uint8) + nethack_render_utils.render_crop( + obs["tty_chars"], + obs["tty_colors"], + obs["tty_cursor"], + self.char_array, + screen, + crop_size=self.crop_size, + ) + obs["screen_image"] = screen + + def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: + obs, reward, done, info = self.env.step(action) + self._populate_obs(obs) + return obs, reward, done, info + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + self._populate_obs(obs) + return obs + + def render(self, mode="human"): + if mode == "rgb_array": + if not self.unwrapped.last_observation: + return + + # TODO: we don't crop but additionally we could show what model sees + obs = self.unwrapped.last_observation + tty_chars = obs[self.unwrapped._observation_keys.index("tty_chars")] + tty_colors = obs[self.unwrapped._observation_keys.index("tty_colors")] + + chw_image_shape = ( + 3, + nethack.nethack.TERMINAL_SHAPE[0] * self.render_char_array.shape[3], + nethack.nethack.TERMINAL_SHAPE[1] * self.render_char_array.shape[4], + ) + out_image = np.zeros(chw_image_shape, dtype=np.uint8) + + _tile_characters_to_image( + out_image=out_image, + chars=tty_chars, + colors=tty_colors, + output_height_chars=nethack.nethack.TERMINAL_SHAPE[0], + output_width_chars=nethack.nethack.TERMINAL_SHAPE[1], + char_array=self.render_char_array, + offset_h=0, + offset_w=0, + ) + + return out_image + else: + return self.env.render() diff --git a/sf_examples/nethack/utils/wrappers/seed_action_space.py b/sf_examples/nethack/utils/wrappers/seed_action_space.py new file mode 100644 index 000000000..2ebf363dc --- /dev/null +++ b/sf_examples/nethack/utils/wrappers/seed_action_space.py @@ -0,0 +1,14 @@ +from typing import Any + +import gymnasium as gym + + +class SeedActionSpaceWrapper(gym.Wrapper): + """ + To have reproducible decorrelate experience we need to seed action space + """ + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + obs, info = self.env.reset(seed=seed, options=options) + self.action_space.seed(seed=seed) + return obs, info diff --git a/sf_examples/nethack/utils/wrappers/task_rewards.py 
diff --git a/sf_examples/nethack/utils/wrappers/task_rewards.py b/sf_examples/nethack/utils/wrappers/task_rewards.py
new file mode 100644
index 000000000..9b58722f3
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/task_rewards.py
@@ -0,0 +1,55 @@
+import gym
+
+from sf_examples.nethack.utils.task_rewards import (
+    EatingScore,
+    GoldScore,
+    ScoutScore,
+    SokobanfillpitScore,
+    SokobansolvedlevelsScore,
+    StaircasePetScore,
+    StaircaseScore,
+)
+
+
+class TaskRewardsInfoWrapper(gym.Wrapper):
+    def __init__(self, env: gym.Env):
+        super().__init__(env)
+
+        self.tasks = [
+            EatingScore(),
+            GoldScore(),
+            ScoutScore(),
+            SokobanfillpitScore(),
+            # SokobansolvedlevelsScore(),  # TODO: it could have bugs, turned off for now
+            StaircasePetScore(),
+            StaircaseScore(),
+        ]
+
+    def reset(self, **kwargs):
+        obs = self.env.reset(**kwargs)
+
+        for task in self.tasks:
+            task.reset_score()
+
+        return obs
+
+    def step(self, action):
+        # use tuple() and copy() to avoid a shallow copy (otherwise `last_observation` would be the same object as `observation`)
+        last_observation = tuple(a.copy() for a in self.env.unwrapped.last_observation)
+        obs, reward, done, info = self.env.step(action)
+        observation = tuple(a.copy() for a in self.env.unwrapped.last_observation)
+        end_status = info["end_status"]
+
+        if done:
+            info["episode_extra_stats"] = self.add_more_stats(info)
+
+        # we accumulate task rewards on every step and log them when the done signal appears
+        for task in self.tasks:
+            task.reward(self.env.unwrapped, last_observation, observation, end_status)
+
+        return obs, reward, done, info
+
+    def add_more_stats(self, info):
+        extra_stats = info.get("episode_extra_stats", {})
+        new_extra_stats = {task.name: task.score for task in self.tasks}
+        return {**extra_stats, **new_extra_stats}
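
For reference, below is a minimal sketch (not part of the patch) of how the two fully-shown wrappers could be stacked around an NLE env by hand. It assumes that NLE registers `NetHackChallenge-v0` with gym, that the compiled `nethack_render_utils` extension from the installation section is available, and that the env follows the old-style 4-tuple `step` API used by the wrappers above; the actual env construction in `sf_examples.nethack.train_nethack` may differ.

```python
# Sketch only: illustrates how the wrappers from this patch compose, not the exact
# pipeline used by sf_examples.nethack.train_nethack.
import gym
import nle  # noqa: F401  # importing nle registers the NetHack envs with gym

from sf_examples.nethack.utils.wrappers.screen_image import RenderCharImagesWithNumpyWrapperV2
from sf_examples.nethack.utils.wrappers.task_rewards import TaskRewardsInfoWrapper

env = gym.make("NetHackChallenge-v0")
# Adds a CHW uint8 "screen_image" observation rendered from tty_chars/tty_colors,
# cropped around the cursor: crop_size=12 chars, each rescaled to 6x6 pixels.
env = RenderCharImagesWithNumpyWrapperV2(env, font_size=9, crop_size=12, rescale_font_size=(6, 6))
# Accumulates per-task scores (eating, gold, scout, ...) and reports them in
# info["episode_extra_stats"] when the episode ends.
env = TaskRewardsInfoWrapper(env)

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs["screen_image"].shape)  # (3, 72, 72) with the defaults above
```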