diff --git a/docs/09-environment-integrations/nethack.md b/docs/09-environment-integrations/nethack.md
new file mode 100644
index 000000000..5dd1f5526
--- /dev/null
+++ b/docs/09-environment-integrations/nethack.md
@@ -0,0 +1,115 @@
+# NetHack
+
+## Installation
+Works with `Python 3.10`. Newer Python versions currently have problems building NLE.
+
+To install the NetHack environment, you need `nle` and its dependencies.
+
+```bash
+# nle dependencies
+apt-get install build-essential python3-dev python3-pip python3-numpy autoconf libtool pkg-config libbz2-dev
+conda install cmake flex bison lit
+
+# install nle locally and modify it to enable seeding and handle rendering with gymnasium
+git clone https://github.com/facebookresearch/nle.git nle && cd nle \
+&& git checkout v0.9.0 && git submodule init && git submodule update --recursive \
+&& sed '/#define NLE_ALLOW_SEEDING 1/i#define NLE_ALLOW_SEEDING 1' include/nleobs.h -i \
+&& sed '/self\.nethack\.set_initial_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/self\.nethack\.set_current_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/self\.nethack\.get_current_seeds = f/d' nle/env/tasks.py -i \
+&& sed '/def seed(self, core=None, disp=None, reseed=True):/d' nle/env/tasks.py -i \
+&& sed '/raise RuntimeError("NetHackChallenge doesn.t allow seed changes")/d' nle/env/tasks.py -i \
+&& sed -i '/def render(self, mode="human"):/a\ if not self.last_observation:\n return' nle/env/base.py \
+&& python setup.py install && cd ..
+
+# install sample factory with nethack extras
+pip install -e .[nethack]
+conda install -c conda-forge pybind11
+pip install -e sf_examples/nethack/nethack_render_utils
+```
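+
+To quickly verify the NLE build, you can create an environment directly (an optional sanity check; `NetHackChallenge-v0` is the environment id that NLE registers with gym):
+
+```bash
+python -c "import nle, gym; env = gym.make('NetHackChallenge-v0'); env.reset(); print('NLE OK')"
+```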
+
+## Running Experiments
+
+Run NetHack experiments with the scripts in `sf_examples.nethack`.
+The default parameters have been chosen to match [Dungeons & Data](https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022), which is in turn based on the [NLE Sample Factory baseline](https://github.com/Miffyli/nle-sample-factory-baseline). By moving from D&D to Sample Factory we managed to increase the APPO score from 2k to 2.8k.
+
+To train a model in the `nethack_challenge` environment:
+
+```bash
+python -m sf_examples.nethack.train_nethack \
+ --env=nethack_challenge \
+ --batch_size=4096 \
+ --num_workers=16 \
+ --num_envs_per_worker=32 \
+ --worker_num_splits=2 \
+ --rollout=32 \
+ --character=mon-hum-neu-mal \
+ --model=ChaoticDwarvenGPT5 \
+ --rnn_size=512 \
+ --experiment=nethack_monk
+```
+
+To visualize the training results, use the `enjoy_nethack` script:
+
+```bash
+python -m sf_examples.nethack.enjoy_nethack --env=nethack_challenge --character=mon-hum-neu-mal --experiment=nethack_monk
+```
+
+Alternatively, you can use the `fast_eval_nethack` script, which evaluates much faster by sampling episodes across multiple parallel workers:
+
+```bash
+python -m sf_examples.nethack.fast_eval_nethack --env=nethack_challenge --sample_env_episodes=128 --num_workers=16 --num_envs_per_worker=2 --character=mon-hum-neu-mal --experiment=nethack_monk
+```
+
+### List of Supported Environments
+
+- nethack_staircase
+- nethack_score
+- nethack_pet
+- nethack_oracle
+- nethack_gold
+- nethack_eat
+- nethack_scout
+- nethack_challenge
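+
+As a rough sketch (mirroring the `sf_examples.nethack.scripts.sample_env` script added in this PR), any environment from this list can also be created programmatically once the NetHack components are registered:
+
+```python
+from sample_factory.envs.create_env import create_env
+from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
+
+register_nethack_components()
+cfg = parse_nethack_args(argv=["--env=nethack_score"], evaluation=True)
+env = create_env(cfg.env, cfg=cfg, render_mode=None)
+env.reset()
+```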
+
+## Results
+
+### Reports
+1. Sample Factory was benchmarked on `nethack_challenge` against Dungeons & Data. With the same parameters, Sample Factory achieved similar sample efficiency to D&D and better running returns (2.8k vs 2k). Training was done on `nethack_challenge` with the human-monk character for 2B env steps.
+ - https://api.wandb.ai/links/bartekcupial/w69fid1w
+
+### Models
+A Sample Factory APPO model trained on the `nethack_challenge` environment is uploaded to the HuggingFace Hub. The model has been trained for 2B env steps.
+
+The model below is the best model from the Dungeons & Data comparison above. The evaluation metrics were obtained by running the model for 1024 episodes.
+
+Model card: https://huggingface.co/LLParallax/sample_factory_human_monk
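+
+To try the checkpoint locally, it can be downloaded with Sample Factory's HuggingFace helper (a sketch; `./train_dir` is just an example target directory) and then evaluated with the `enjoy_nethack` / `fast_eval_nethack` scripts above:
+
+```bash
+python -m sample_factory.huggingface.load_from_hub -r LLParallax/sample_factory_human_monk -d ./train_dir
+```
+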
+Evaluation results:
+```
+{
+ "reward/reward": 3245.3828125,
+ "reward/reward_min": 20.0,
+ "reward/reward_max": 18384.0,
+ "len/len": 2370.4560546875,
+ "len/len_min": 27.0,
+ "len/len_max": 21374.0,
+ "policy_stats/avg_score": 3245.4716796875,
+ "policy_stats/avg_turns": 14693.970703125,
+ "policy_stats/avg_dlvl": 1.13671875,
+ "policy_stats/avg_max_hitpoints": 46.42578125,
+ "policy_stats/avg_max_energy": 34.00390625,
+ "policy_stats/avg_armor_class": 4.68359375,
+ "policy_stats/avg_experience_level": 6.13671875,
+ "policy_stats/avg_experience_points": 663.375,
+ "policy_stats/avg_eating_score": 14063.2587890625,
+ "policy_stats/avg_gold_score": 76.033203125,
+ "policy_stats/avg_scout_score": 499.0478515625,
+ "policy_stats/avg_sokobanfillpit_score": 0.0,
+ "policy_stats/avg_staircase_pet_score": 0.005859375,
+ "policy_stats/avg_staircase_score": 4.9970703125,
+ "policy_stats/avg_episode_number": 1.5,
+ "policy_stats/avg_true_objective": 3245.3828125,
+ "policy_stats/avg_true_objective_min": 20.0,
+ "policy_stats/avg_true_objective_max": 18384.0
+}
+```
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 46254b904..a85ae51c2 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,16 @@
_atari_deps = ["gymnasium[atari, accept-rom-license]"]
_mujoco_deps = ["gymnasium[mujoco]", "mujoco<2.5"]
+_nethack_deps = [
+ "numba ~= 0.58",
+ "pandas ~= 2.1",
+ "matplotlib ~= 3.8",
+ "seaborn ~= 0.12",
+ "scipy ~= 1.11",
+ "shimmy",
+ "tqdm ~= 4.66",
+ "debugpy ~= 1.6",
+]
_envpool_deps = ["envpool"]
_docs_deps = [
@@ -67,6 +77,7 @@
"atari": _atari_deps,
"envpool": _envpool_deps,
"mujoco": _mujoco_deps,
+ "nethack": _nethack_deps,
"vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"],
# "dmlab": ["dm_env"], <-- these are just auxiliary packages, the main package has to be built from sources
},
diff --git a/sf_examples/nethack/enjoy_nethack.py b/sf_examples/nethack/enjoy_nethack.py
new file mode 100644
index 000000000..443618e50
--- /dev/null
+++ b/sf_examples/nethack/enjoy_nethack.py
@@ -0,0 +1,30 @@
+import sys
+
+from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
+from sample_factory.enjoy import enjoy
+from sf_examples.nethack.nethack_params import (
+ add_extra_params_general,
+ add_extra_params_model,
+ add_extra_params_nethack_env,
+ nethack_override_defaults,
+)
+from sf_examples.nethack.train_nethack import register_nethack_components
+
+
+def main(): # pragma: no cover
+ """Script entry point."""
+ register_nethack_components()
+
+ parser, cfg = parse_sf_args(evaluation=True)
+ add_extra_params_nethack_env(parser)
+ add_extra_params_model(parser)
+ add_extra_params_general(parser)
+ nethack_override_defaults(cfg.env, parser)
+ cfg = parse_full_cfg(parser)
+
+ status = enjoy(cfg)
+ return status
+
+
+if __name__ == "__main__": # pragma: no cover
+ sys.exit(main())
diff --git a/sf_examples/nethack/fast_eval_nethack.py b/sf_examples/nethack/fast_eval_nethack.py
new file mode 100644
index 000000000..d05572698
--- /dev/null
+++ b/sf_examples/nethack/fast_eval_nethack.py
@@ -0,0 +1,35 @@
+import sys
+
+from sample_factory.cfg.arguments import checkpoint_override_defaults, parse_full_cfg, parse_sf_args
+from sample_factory.eval import do_eval
+from sf_examples.nethack.nethack_params import (
+ add_extra_params_general,
+ add_extra_params_model,
+ add_extra_params_nethack_env,
+ nethack_override_defaults,
+)
+from sf_examples.nethack.train_nethack import register_nethack_components
+
+
+def main(): # pragma: no cover
+ """Script entry point."""
+ register_nethack_components()
+
+ parser, cfg = parse_sf_args(evaluation=True)
+ add_extra_params_nethack_env(parser)
+ add_extra_params_model(parser)
+ add_extra_params_general(parser)
+ nethack_override_defaults(cfg.env, parser)
+
+ # important, instead of `load_from_checkpoint` as in enjoy we want
+ # to override it here to be able to use argv arguments
+ checkpoint_override_defaults(cfg, parser)
+
+ cfg = parse_full_cfg(parser)
+
+ status = do_eval(cfg)
+ return status
+
+
+if __name__ == "__main__": # pragma: no cover
+ sys.exit(main())
diff --git a/sf_examples/nethack/models/__init__.py b/sf_examples/nethack/models/__init__.py
new file mode 100644
index 000000000..70c1a1138
--- /dev/null
+++ b/sf_examples/nethack/models/__init__.py
@@ -0,0 +1,6 @@
+from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5
+
+MODELS = [
+ ChaoticDwarvenGPT5,
+]
+MODELS_LOOKUP = {c.__name__: c for c in MODELS}
diff --git a/sf_examples/nethack/models/chaotic_dwarf.py b/sf_examples/nethack/models/chaotic_dwarf.py
new file mode 100644
index 000000000..b2d0df458
--- /dev/null
+++ b/sf_examples/nethack/models/chaotic_dwarf.py
@@ -0,0 +1,304 @@
+"""Adapted from Chaos Dwarf in Nethack Challenge Starter Kit:
+https://github.com/Miffyli/nle-sample-factory-baseline
+
+MIT License
+
+Copyright (c) 2021 Anssi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+import torch
+from nle import nethack
+from torch import nn
+from torch.nn import functional as F
+
+from sample_factory.algo.utils.torch_utils import calc_num_elements
+from sample_factory.model.encoder import Encoder
+from sample_factory.utils.typing import Config, ObsSpace
+
+
+class MessageEncoder(nn.Module):
+ def __init__(self):
+ super(MessageEncoder, self).__init__()
+ self.hidden_dim = 128
+ self.msg_fwd = nn.Sequential(
+ nn.Linear(nethack.MESSAGE_SHAPE[0], 128),
+ nn.ELU(inplace=True),
+ nn.Linear(128, self.hidden_dim),
+ nn.ELU(inplace=True),
+ )
+
+ def forward(self, message):
+ return self.msg_fwd(message / 255.0)
+
+
+class BLStatsEncoder(nn.Module):
+ def __init__(self):
+ super(BLStatsEncoder, self).__init__()
+ self.hidden_dim = 128 + nethack.BLSTATS_SHAPE[0]
+ self.blstats_fwd = nn.Sequential(
+ nn.Linear(nethack.BLSTATS_SHAPE[0], 128),
+ nn.ELU(inplace=True),
+ nn.Linear(128, 128),
+ nn.ELU(inplace=True),
+ )
+
+ normalization_stats = torch.tensor(
+ [
+ 1.0 / 79.0, # hero col
+ 1.0 / 21, # hero row
+ 0.0, # strength pct
+ 1.0 / 10, # strength
+ 1.0 / 10, # dexterity
+ 1.0 / 10, # constitution
+ 1.0 / 10, # intelligence
+ 1.0 / 10, # wisdom
+ 1.0 / 10, # charisma
+ 0.0, # score
+ 1.0 / 10, # hitpoints
+ 1.0 / 10, # max hitpoints
+ 0.0, # depth
+ 1.0 / 1000, # gold
+ 1.0 / 10, # energy
+ 1.0 / 10, # max energy
+ 1.0 / 10, # armor class
+ 0.0, # monster level
+ 1.0 / 10, # experience level
+ 1.0 / 100, # experience points
+ 1.0 / 1000, # time
+ 1.0, # hunger_state
+ 1.0 / 10, # carrying capacity
+ 0.0, # dungeon number
+ 0.0, # level number
+ 0.0, # condition bits
+ 0.0, # alignment bits
+ ],
+ requires_grad=False,
+ )
+ self.register_buffer("normalization_stats", normalization_stats)
+
+ self.blstat_range = (-5, 5)
+
+ def forward(self, blstats):
+ norm_bls = torch.clip(
+ blstats * self.normalization_stats,
+ self.blstat_range[0],
+ self.blstat_range[1],
+ )
+
+ return torch.cat([self.blstats_fwd(norm_bls), norm_bls], dim=-1)
+
+
+class TopLineEncoder(nn.Module):
+ def __init__(self):
+ super(TopLineEncoder, self).__init__()
+ self.hidden_dim = 128
+ self.i_dim = nethack.NLE_TERM_CO * 256
+
+ self.msg_fwd = nn.Sequential(
+ nn.Linear(self.i_dim, 128),
+ nn.ELU(inplace=True),
+ nn.Linear(128, self.hidden_dim),
+ nn.ELU(inplace=True),
+ )
+
+ def forward(self, message):
+ # one-hot encode the raw byte value (0-255) of every character on the top line
+ message_normed = F.one_hot((message).long(), 256).reshape(-1, self.i_dim).float()
+ return self.msg_fwd(message_normed)
+
+
+class BottomLinesEncoder(nn.Module):
+ def __init__(self):
+ super(BottomLinesEncoder, self).__init__()
+ self.conv_layers = []
+ w = nethack.NLE_TERM_CO * 2
+ for in_ch, out_ch, filter, stride in [[2, 32, 8, 4], [32, 64, 4, 1]]:
+ self.conv_layers.append(nn.Conv1d(in_ch, out_ch, filter, stride=stride))
+ self.conv_layers.append(nn.ELU(inplace=True))
+ w = conv_outdim(w, filter, padding=0, stride=stride)
+
+ self.conv_net = nn.Sequential(*self.conv_layers)
+ self.fwd_net = nn.Sequential(
+ nn.Linear(w * out_ch, 128),
+ nn.ELU(),
+ nn.Linear(128, 128),
+ nn.ELU(),
+ )
+ self.hidden_dim = 128
+
+ def forward(self, bottom_lines):
+ B, D = bottom_lines.shape
+ # ASCII 32: ' ', ASCII [33-128]: visible characters
+ chars_normalised = (bottom_lines - 32) / 96
+
+ # ASCII [45-57]: -./0123456789
+ numbers_mask = (bottom_lines > 44) * (bottom_lines < 58)
+ digits_normalised = numbers_mask * (bottom_lines - 47) / 10
+
+ # Put in different channels & conv (B, 2, D)
+ x = torch.stack([chars_normalised, digits_normalised], dim=1)
+ return self.fwd_net(self.conv_net(x).view(B, -1))
+
+
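+# Example: conv_outdim(80, 8, stride=6) == int(1 + (80 - 8) / 6) == 13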
+def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1):
+ """Return the dimension after applying a convolution along one axis"""
+ return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride)
+
+
+class InverseModel(nn.Module):
+ def __init__(self, h_dim, action_space):
+ super(InverseModel, self).__init__()
+ self.h_dim = h_dim * 2
+ self.action_space = action_space
+
+ self.fwd_model = nn.Sequential(
+ nn.Linear(self.h_dim, 128),
+ nn.ELU(inplace=True),
+ nn.Linear(128, 128),
+ nn.ELU(inplace=True),
+ nn.Linear(128, action_space),
+ )
+
+ def forward(self, obs):
+ T, B, *_ = obs.shape
+ x = torch.cat([obs[:-1], obs[1:]], dim=-1)
+ pred_a = self.fwd_model(x)
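+ # pad predictions with a dummy row of -1 logits so the output has length T, matching the input sequence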
+ off_by_one = torch.ones((1, B, self.action_space), device=x.device) * -1
+ return torch.cat([pred_a, off_by_one], dim=0)
+
+
+class ScreenEncoder(nn.Module):
+ def __init__(self, screen_shape):
+ super(ScreenEncoder, self).__init__()
+ conv_layers = []
+
+ self.h, self.w = screen_shape
+ self.hidden_dim = 512
+
+ self.conv_filters = [
+ [3, 32, 8, 6, 1],
+ [32, 64, 4, 2, 1],
+ [64, 128, 3, 2, 1],
+ [128, 128, 3, 1, 1],
+ ]
+
+ for (
+ in_channels,
+ out_channels,
+ filter_size,
+ stride,
+ dilation,
+ ) in self.conv_filters:
+ conv_layers.append(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ filter_size,
+ stride=stride,
+ dilation=dilation,
+ )
+ )
+ conv_layers.append(nn.ELU(inplace=True))
+
+ self.h = conv_outdim(self.h, filter_size, padding=0, stride=stride, dilation=dilation)
+ self.w = conv_outdim(self.w, filter_size, padding=0, stride=stride, dilation=dilation)
+
+ self.conv_head = nn.Sequential(*conv_layers)
+ self.out_size = self.h * self.w * out_channels
+
+ self.fc_head = nn.Sequential(nn.Linear(self.out_size, self.hidden_dim), nn.ELU(inplace=True))
+
+ def forward(self, screen_image):
+ x = self.conv_head(screen_image / 255.0)
+ x = x.view(-1, self.out_size)
+ x = self.fc_head(x)
+ return x
+
+
+class ChaoticDwarvenGPT5(Encoder):
+ def __init__(self, cfg: Config, obs_space: ObsSpace):
+ super().__init__(cfg)
+ self.obs_keys = list(sorted(obs_space.keys())) # always the same order
+ self.encoders = nn.ModuleDict()
+
+ self.use_tty_only = cfg.use_tty_only
+ self.use_prev_action = cfg.use_prev_action
+
+ # screen encoder (TODO: could also use only tty_chars)
+ pixel_size = cfg.pixel_size
+ if cfg.crop_dim == 0:
+ screen_shape = (24 * pixel_size, 80 * pixel_size)
+ else:
+ screen_shape = (cfg.crop_dim * pixel_size, cfg.crop_dim * pixel_size)
+ self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape))
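+ # from here on, screen_shape is the actual (C, H, W) shape of the screen_image observation, used for calc_num_elements below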
+ screen_shape = obs_space["screen_image"].shape
+
+ # top and bottom encoders
+ if self.use_tty_only:
+ self.topline_encoder = TopLineEncoder()
+ self.bottomline_encoder = torch.jit.script(BottomLinesEncoder())
+ topline_shape = (obs_space["tty_chars"].shape[1],)
+ bottomline_shape = (2 * obs_space["tty_chars"].shape[1],)
+ else:
+ self.topline_encoder = torch.jit.script(MessageEncoder())
+ self.bottomline_encoder = torch.jit.script(BLStatsEncoder())
+ topline_shape = obs_space["message"].shape
+ bottomline_shape = obs_space["blstats"].shape
+
+ if self.use_prev_action:
+ self.num_actions = obs_space["prev_actions"].n
+ self.prev_actions_dim = self.num_actions
+ else:
+ self.num_actions = None
+ self.prev_actions_dim = 0
+
+ self.encoder_out_size = sum(
+ [
+ calc_num_elements(self.screen_encoder, screen_shape),
+ calc_num_elements(self.topline_encoder, topline_shape),
+ calc_num_elements(self.bottomline_encoder, bottomline_shape),
+ self.prev_actions_dim,
+ ]
+ )
+
+ def forward(self, obs_dict):
+ B, C, H, W = obs_dict["screen_image"].shape
+
+ if self.use_tty_only:
+ topline = obs_dict["tty_chars"][..., 0, :]
+ bottom_line = obs_dict["tty_chars"][..., -2:, :]
+ else:
+ topline = obs_dict["message"]
+ bottom_line = obs_dict["blstats"]
+
+ encodings = [
+ self.topline_encoder(topline.float(memory_format=torch.contiguous_format).view(B, -1)),
+ self.bottomline_encoder(bottom_line.float(memory_format=torch.contiguous_format).view(B, -1)),
+ self.screen_encoder(obs_dict["screen_image"].float(memory_format=torch.contiguous_format).view(B, C, H, W)),
+ ]
+
+ if self.use_prev_action:
+ prev_actions = obs_dict["prev_actions"].long().view(B)
+ encodings.append(torch.nn.functional.one_hot(prev_actions, self.num_actions))
+
+ return torch.cat(encodings, dim=1)
+
+ def get_out_size(self) -> int:
+ return self.encoder_out_size
diff --git a/sf_examples/nethack/nethack_env.py b/sf_examples/nethack/nethack_env.py
new file mode 100644
index 000000000..b870aad6d
--- /dev/null
+++ b/sf_examples/nethack/nethack_env.py
@@ -0,0 +1,107 @@
+from typing import Optional
+
+from nle.env.tasks import (
+ NetHackChallenge,
+ NetHackEat,
+ NetHackGold,
+ NetHackOracle,
+ NetHackScore,
+ NetHackScout,
+ NetHackStaircase,
+ NetHackStaircasePet,
+)
+
+from sample_factory.algo.utils.gymnasium_utils import patch_non_gymnasium_env
+from sf_examples.nethack.utils.wrappers import (
+ BlstatsInfoWrapper,
+ PrevActionsWrapper,
+ RenderCharImagesWithNumpyWrapperV2,
+ SeedActionSpaceWrapper,
+ TaskRewardsInfoWrapper,
+)
+
+NETHACK_ENVS = dict(
+ nethack_staircase=NetHackStaircase,
+ nethack_score=NetHackScore,
+ nethack_pet=NetHackStaircasePet,
+ nethack_oracle=NetHackOracle,
+ nethack_gold=NetHackGold,
+ nethack_eat=NetHackEat,
+ nethack_scout=NetHackScout,
+ nethack_challenge=NetHackChallenge,
+)
+
+
+def nethack_env_by_name(name):
+ if name in NETHACK_ENVS.keys():
+ return NETHACK_ENVS[name]
+ else:
+ raise Exception("Unknown NetHack env")
+
+
+def make_nethack_env(env_name, cfg, env_config, render_mode: Optional[str] = None):
+ assert render_mode in (None, "human", "full", "ansi", "string", "rgb_array")
+
+ env_class = nethack_env_by_name(env_name)
+
+ observation_keys = (
+ "message",
+ "blstats",
+ "tty_chars",
+ "tty_colors",
+ "tty_cursor",
+ # ALSO AVAILABLE (OFF for speed)
+ # "specials",
+ # "colors",
+ # "chars",
+ # "glyphs",
+ # "inv_glyphs",
+ # "inv_strs",
+ # "inv_letters",
+ # "inv_oclasses",
+ )
+
+ kwargs = dict(
+ character=cfg.character,
+ max_episode_steps=cfg.max_episode_steps,
+ observation_keys=observation_keys,
+ penalty_step=cfg.penalty_step,
+ penalty_time=cfg.penalty_time,
+ penalty_mode=cfg.fn_penalty_step,
+ savedir=cfg.savedir,
+ save_ttyrec_every=cfg.save_ttyrec_every,
+ )
+ if env_name == "nethack_challenge":
+ kwargs["no_progress_timeout"] = 150
+
+ if env_name in ("nethack_staircase", "nethack_pet", "nethack_oracle"):
+ kwargs.update(reward_win=cfg.reward_win, reward_lose=cfg.reward_lose)
+ # else: # print warning once
+ # warnings.warn("Ignoring cfg.reward_win and cfg.reward_lose")
+
+ env = env_class(**kwargs)
+
+ if cfg.add_image_observation:
+ env = RenderCharImagesWithNumpyWrapperV2(
+ env,
+ crop_size=cfg.crop_dim,
+ rescale_font_size=(cfg.pixel_size, cfg.pixel_size),
+ )
+
+ if cfg.use_prev_action:
+ env = PrevActionsWrapper(env)
+
+ if cfg.add_stats_to_info:
+ env = BlstatsInfoWrapper(env)
+ env = TaskRewardsInfoWrapper(env)
+
+ env = patch_non_gymnasium_env(env)
+
+ if render_mode:
+ env.render_mode = render_mode
+
+ if cfg.serial_mode and cfg.num_workers == 1:
+ # full reproducibility can only be achieved in serial mode and when there is only 1 worker
+ env = SeedActionSpaceWrapper(env)
+
+ return env
diff --git a/sf_examples/nethack/nethack_params.py b/sf_examples/nethack/nethack_params.py
new file mode 100644
index 000000000..926a8b2a7
--- /dev/null
+++ b/sf_examples/nethack/nethack_params.py
@@ -0,0 +1,134 @@
+from sample_factory.utils.utils import str2bool
+
+
+def add_extra_params_nethack_env(parser):
+ """
+ Specify any additional command line arguments for NetHack environments.
+ """
+ p = parser
+ p.add_argument(
+ "--character", type=str, default="mon-hum-neu-mal", help="name of character. Defaults to 'mon-hum-neu-mal'."
+ )
+ p.add_argument(
+ "--max_episode_steps",
+ type=int,
+ default=100000,
+ help="maximum amount of steps allowed before the game is forcefully quit. In such cases, `info 'end_status']` will be equal to `StepStatus.ABORTED`",
+ )
+ p.add_argument(
+ "--penalty_step", type=float, default=0.0, help="constant applied to amount of frozen steps. Defaults to 0.0."
+ )
+ p.add_argument(
+ "--penalty_time", type=float, default=0.0, help="constant applied to amount of frozen steps. Defaults to 0.0."
+ )
+ p.add_argument(
+ "--fn_penalty_step",
+ type=str,
+ default="constant",
+ help="name of the mode for calculating the time step penalty. Can be `constant`, `exp`, `square`, `linear`, or `always`. Defaults to `constant`.",
+ )
+ p.add_argument(
+ "--savedir",
+ type=str,
+ default=None,
+ help="Path to save ttyrecs (game recordings) into, if save_ttyrec_every is nonzero. If nonempty string, interpreted as a path to a new or existing directory. If "
+ " (empty string) or None, NLE choses a unique directory name.Defaults to `None`.",
+ )
+ p.add_argument(
+ "--save_ttyrec_every",
+ type=int,
+ default=0,
+ help="Integer, if 0, no ttyrecs (game recordings) will be saved. Otherwise, save a ttyrec every Nth episode.",
+ )
+ p.add_argument(
+ "--add_image_observation",
+ type=str2bool,
+ default=True,
+ help="If True, additional wrapper will render screen image. Defaults to `True`.",
+ )
+ p.add_argument("--crop_dim", type=int, default=18, help="Crop image around the player. Defaults to `18`.")
+ p.add_argument(
+ "--pixel_size",
+ type=int,
+ default=6,
+ help="Rescales each character to size of `(pixel_size, pixel_size). Defaults to `6`.",
+ )
+
+
+def add_extra_params_model(parser):
+ """
+ Specify any additional command line arguments for NetHack models.
+ """
+ p = parser
+ p.add_argument(
+ "--use_prev_action",
+ type=str2bool,
+ default=True,
+ help="If True, the model will use previous action. Defaults to `True`",
+ )
+ p.add_argument(
+ "--use_tty_only",
+ type=str2bool,
+ default=True,
+ help="If True, the model will use tty_chars for the topline and bottomline. Defaults to `True`",
+ )
+
+
+def add_extra_params_general(parser):
+ """
+ Specify any additional command line arguments for NetHack.
+ """
+ p = parser
+ p.add_argument(
+ "--model", type=str, default="ChaoticDwarvenGPT5", help="Name of the model. Defaults to `ChaoticDwarvenGPT5`."
+ )
+ p.add_argument(
+ "--add_stats_to_info",
+ type=str2bool,
+ default=True,
+ help="If True, adds wrapper which loggs additional statisics. Defaults to `True`.",
+ )
+
+
+def nethack_override_defaults(_env, parser):
+ """RL params specific to NetHack envs."""
+ # set hyperparameter values to the same as in d&d
+ parser.set_defaults(
+ use_record_episode_statistics=False,
+ gamma=0.999,
+ num_workers=12,
+ num_envs_per_worker=2,
+ worker_num_splits=2,
+ train_for_env_steps=2_000_000_000,
+ nonlinearity="relu",
+ use_rnn=True,
+ rnn_type="lstm",
+ actor_critic_share_weights=True,
+ policy_initialization="orthogonal",
+ policy_init_gain=1.0,
+ adaptive_stddev=False, # True only for continuous action distributions
+ reward_scale=1.0,
+ reward_clip=10.0,
+ batch_size=1024,
+ rollout=32,
+ max_grad_norm=4,
+ num_epochs=1,
+ num_batches_per_epoch=1, # can be used for increasing the batch_size for SGD
+ ppo_clip_ratio=0.1,
+ ppo_clip_value=1.0,
+ value_loss_coeff=1.0,
+ exploration_loss="entropy",
+ exploration_loss_coeff=0.001,
+ learning_rate=0.0001,
+ gae_lambda=1.0,
+ with_vtrace=False, # in D&D V-trace was used
+ normalize_input=False, # turn off for now and use the normalization from D&D
+ normalize_returns=True,
+ async_rl=True,
+ experiment_summaries_interval=50,
+ adam_beta1=0.9,
+ adam_beta2=0.999,
+ adam_eps=1e-7,
+ seed=22,
+ save_every_sec=120,
+ )
diff --git a/sf_examples/nethack/nethack_render_utils/CMakeLists.txt b/sf_examples/nethack/nethack_render_utils/CMakeLists.txt
new file mode 100644
index 000000000..727be7ad8
--- /dev/null
+++ b/sf_examples/nethack/nethack_render_utils/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.4...3.18)
+project(nethack_render_utils VERSION 0.0.1)
+
+find_package(pybind11 REQUIRED)
+include_directories(${pybind11_INCLUDE_DIR})
+
+pybind11_add_module(nethack_render_utils src/main.cpp)
+
+# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
+# define (VERSION_INFO) here.
+target_compile_definitions(nethack_render_utils
+ PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO})
\ No newline at end of file
diff --git a/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf b/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf
new file mode 100644
index 000000000..097db1814
Binary files /dev/null and b/sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf differ
diff --git a/sf_examples/nethack/nethack_render_utils/setup.py b/sf_examples/nethack/nethack_render_utils/setup.py
new file mode 100644
index 000000000..aaae05690
--- /dev/null
+++ b/sf_examples/nethack/nethack_render_utils/setup.py
@@ -0,0 +1,127 @@
+import os
+import re
+import subprocess
+import sys
+
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext
+
+# Convert distutils Windows platform specifiers to CMake -A arguments
+PLAT_TO_CMAKE = {
+ "win32": "Win32",
+ "win-amd64": "x64",
+ "win-arm32": "ARM",
+ "win-arm64": "ARM64",
+}
+
+
+# A CMakeExtension needs a sourcedir instead of a file list.
+# The name must be the _single_ output extension from the CMake build.
+# If you need multiple extensions, see scikit-build.
+class CMakeExtension(Extension):
+ def __init__(self, name, sourcedir=""):
+ Extension.__init__(self, name, sources=[])
+ self.sourcedir = os.path.abspath(sourcedir)
+
+
+class CMakeBuild(build_ext):
+ def build_extension(self, ext):
+ extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
+
+ # required for auto-detection & inclusion of auxiliary "native" libs
+ if not extdir.endswith(os.path.sep):
+ extdir += os.path.sep
+
+ debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
+ cfg = "Debug" if debug else "Release"
+
+ # CMake lets you override the generator - we need to check this.
+ # Can be set with Conda-Build, for example.
+ cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
+
+ # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
+ # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
+ # from Python.
+ cmake_args = [
+ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
+ f"-DPYTHON_EXECUTABLE={sys.executable}",
+ f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
+ ]
+ build_args = []
+ # Adding CMake arguments set as environment variable
+ # (needed e.g. to build for ARM OSx on conda-forge)
+ if "CMAKE_ARGS" in os.environ:
+ cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
+
+ # In this example, we pass in the version to C++. You might not need to.
+ cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
+
+ if self.compiler.compiler_type != "msvc":
+ # Using Ninja-build since it a) is available as a wheel and b)
+ # multithreads automatically. MSVC would require all variables be
+ # exported for Ninja to pick it up, which is a little tricky to do.
+ # Users can override the generator with CMAKE_GENERATOR in CMake
+ # 3.15+.
+ if not cmake_generator:
+ try:
+ import ninja # noqa: F401
+
+ cmake_args += ["-GNinja"]
+ except ImportError:
+ pass
+
+ else:
+ # Single config generators are handled "normally"
+ single_config = any(x in cmake_generator for x in {"NMake", "Ninja"})
+
+ # CMake allows an arch-in-generator style for backward compatibility
+ contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"})
+
+ # Specify the arch if using MSVC generator, but only if it doesn't
+ # contain a backward-compatibility arch spec already in the
+ # generator name.
+ if not single_config and not contains_arch:
+ cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]]
+
+ # Multi-config generators have a different way to specify configs
+ if not single_config:
+ cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
+ build_args += ["--config", cfg]
+
+ if sys.platform.startswith("darwin"):
+ # Cross-compile support for macOS - respect ARCHFLAGS if set
+ archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
+ if archs:
+ cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
+
+ # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
+ # across all generators.
+ if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
+ # self.parallel is a Python 3 only way to set parallel jobs by hand
+ # using -j in the build_ext call, not supported by pip or PyPA-build.
+ if hasattr(self, "parallel") and self.parallel:
+ # CMake 3.12+ only.
+ build_args += [f"-j{self.parallel}"]
+
+ if not os.path.exists(self.build_temp):
+ os.makedirs(self.build_temp)
+
+ subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp)
+ subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
+
+
+# The information here can also be placed in setup.cfg - better separation of
+# logic and declaration, and simpler if you include description/version in a file.
+setup(
+ name="nethack_render_utils",
+ version="0.0.1",
+ author="Eric Hambro",
+ author_email="ehambro@fb.com",
+ description="Render NetHack glyphs as a screen",
+ long_description="",
+ ext_modules=[CMakeExtension("nethack_render_utils")],
+ cmdclass={"build_ext": CMakeBuild},
+ zip_safe=False,
+ extras_require={"test": ["pytest>=6.0"]},
+ python_requires=">=3.6",
+)
diff --git a/sf_examples/nethack/nethack_render_utils/src/main.cpp b/sf_examples/nethack/nethack_render_utils/src/main.cpp
new file mode 100644
index 000000000..36ff5019d
--- /dev/null
+++ b/sf_examples/nethack/nethack_render_utils/src/main.cpp
@@ -0,0 +1,213 @@
+#include <pybind11/pybind11.h>
+
+#include <pybind11/numpy.h>
+#include <algorithm>
+#include <cstdint>
+#include <stdexcept>
+
+#define STRINGIFY(x) #x
+#define MACRO_STRINGIFY(x) STRINGIFY(x)
+
+namespace py = pybind11;
+
+template <typename T> void check_array_types(py::handle h, int min_dim) {
+ if (h.is_none())
+ throw std::invalid_argument("Array is None!");
+ ;
+ if (!py::isinstance(h))
+ throw std::invalid_argument("Numpy array required.");
+
+ py::array array = py::array::ensure(h);
+ if (!array.dtype().is(py::dtype::of<T>()))
+ throw std::invalid_argument("Buffer dtype mismatch.");
+
+ if (!(array.flags() & py::array::c_style))
+ throw std::invalid_argument("Array isn't C contiguous.");
+
+ py::buffer_info buf = array.request();
+ if (buf.ndim < min_dim)
+ throw std::invalid_argument("Wrong ndim in array");
+}
+
+void check_array_shapes(py::handle tty_chars, py::handle tty_colors,
+ py::handle tty_cursor, py::handle glyph_images,
+ py::handle out, int crop_size) {
+
+ const auto &chars_shape = py::array::ensure(tty_chars).request().shape;
+ const auto &colors_shape = py::array::ensure(tty_colors).request().shape;
+ if (!std::equal(chars_shape.begin(), chars_shape.end(), colors_shape.begin()))
+ throw std::invalid_argument("Shape mismatch (tty_chars, tty_colors).");
+
+ const auto &img_shape = py::array::ensure(glyph_images).request().shape;
+ if (!(img_shape[0] == 256 && img_shape[1] == 16))
+ throw std::invalid_argument("Shape of glyph_images must start (256, 16)");
+
+ int crop_height =
+ (crop_size > 0) ? crop_size : chars_shape[chars_shape.size() - 2];
+ int crop_width =
+ (crop_size > 0) ? crop_size : chars_shape[chars_shape.size() - 1];
+ const auto &out_shape = py::array::ensure(out).request().shape;
+ size_t dim = out_shape.size();
+ if (!(out_shape[dim - 3] == img_shape[2] &&
+ out_shape[dim - 2] == (crop_height * img_shape[3]) &&
+ out_shape[dim - 1] == (crop_width * img_shape[4])))
+ throw std::invalid_argument("Shape mismatch (glyph_images, out).");
+
+ const auto &cursor_shape = py::array::ensure(tty_cursor).request().shape;
+ if (!(cursor_shape[cursor_shape.size() - 1] == 2))
+ throw std::invalid_argument("Shape of glyph_images must be (2)");
+
+ if (!(chars_shape.size() - 2 == cursor_shape.size() - 1 &&
+ chars_shape.size() - 2 == out_shape.size() - 3)) {
+ throw std::invalid_argument("Different dims for batch conversion");
+ }
+
+ for (int i = 0; i < chars_shape.size() - 2; ++i) {
+ if (!(chars_shape[i] == cursor_shape[i] &&
+ chars_shape[i] == out_shape[i])) {
+ throw std::invalid_argument("Different batch sizes for batch conversion");
+ }
+ }
+}
+
+void tile_crop(py::array_t<uint8_t> tty_chars, py::array_t<int8_t> tty_colors,
+ py::array_t<uint8_t> tty_cursor, py::array_t<uint8_t> images,
+ py::array_t<uint8_t> out_array, int crop_size) {
+
+ py::buffer_info chars_buff = tty_chars.request();
+ py::buffer_info colors_buff = tty_colors.request();
+ py::buffer_info cursor_buff = tty_cursor.request();
+ py::buffer_info images_buff = images.request();
+ py::buffer_info out_buff = out_array.request();
+
+ const auto &chars_shape = chars_buff.shape;
+ const auto &img_shape = images_buff.shape;
+ const auto &out_shape = out_buff.shape;
+
+ int lead_dims = chars_shape.size() - 2;
+ int lead_elems = 1;
+ for (int i = 0; i < lead_dims; ++i) {
+ lead_elems *= chars_shape[i];
+ }
+
+ int rows = chars_shape[lead_dims + 0];
+ int cols = chars_shape[lead_dims + 1];
+
+ int img_colors = img_shape[1];
+ int img_channels = img_shape[2];
+ int img_rows = img_shape[3];
+ int img_cols = img_shape[4];
+
+ int out_chan = out_shape[lead_dims + 0];
+ int out_rows = out_shape[lead_dims + 1];
+ int out_cols = out_shape[lead_dims + 2];
+
+ uint8_t *char_ptr = static_cast<uint8_t *>(chars_buff.ptr);
+ int8_t *color_ptr = static_cast<int8_t *>(colors_buff.ptr);
+ uint8_t *out_ptr = static_cast<uint8_t *>(out_buff.ptr);
+ uint8_t *cur_ptr = static_cast<uint8_t *>(cursor_buff.ptr);
+ uint8_t *img_ptr = static_cast<uint8_t *>(images_buff.ptr);
+
+ int half_crop_size = crop_size / 2;
+
+ // Strides
+ int s_char_frame = rows * cols;
+ int s_char_row = cols;
+
+ int s_color_frame = rows * cols;
+ int s_color_row = cols;
+
+ int s_cursor_frame = 2;
+
+ int s_img_col = img_cols;
+ int s_img_row = img_rows * img_cols;
+ int s_img_color = img_channels * img_rows * img_cols;
+ int s_img_glyph = img_colors * img_channels * img_rows * img_cols;
+
+ int s_out_frame = out_chan * out_rows * out_cols;
+ int s_out_chan = out_rows * out_cols;
+ int s_out_row = out_cols;
+ {
+ py::gil_scoped_release release;
+
+ for (size_t i = 0; i < lead_elems; ++i) {
+ auto chars_at = [char_ptr, s_char_row](int h, int w) {
+ return *(char_ptr + h * s_char_row + w);
+ };
+ auto colors_at = [color_ptr, s_color_row](int h, int w) {
+ return *(color_ptr + h * s_color_row + w);
+ };
+ auto img_at = [img_ptr, s_img_glyph, s_img_color, s_img_row,
+ s_img_col](int glyph, int color, int chan, int h, int w) {
+ return *(img_ptr + glyph * s_img_glyph + color * s_img_color +
+ chan * s_img_row + h * s_img_col + w);
+ };
+ auto out_ptr_ = [out_ptr, s_out_chan, s_out_row](int chan, int h, int w) {
+ return (out_ptr + chan * s_out_chan + h * s_out_row + w);
+ };
+
+ int start_h = (crop_size > 0) ? cur_ptr[0] - half_crop_size : 0;
+ int start_w = (crop_size > 0) ? cur_ptr[1] - half_crop_size : 0;
+
+ int max_r = (crop_size > 0) ? crop_size : rows;
+ int max_c = (crop_size > 0) ? crop_size : cols;
+ for (size_t r = 0; r < max_r; ++r) {
+ int h = r + start_h;
+ for (size_t c = 0; c < max_c; ++c) {
+ int w = c + start_w;
+ for (size_t i_chan = 0; i_chan < img_channels; ++i_chan) {
+ for (size_t i_r = 0; i_r < img_rows; ++i_r) {
+ for (size_t i_c = 0; i_c < img_cols; ++i_c) {
+
+ if ((h < 0 || h >= rows || w < 0 || w >= cols)) {
+ *out_ptr_(i_chan, r * img_rows + i_r, c * img_cols + i_c) = 0;
+ } else {
+ int this_glyph = chars_at(h, w);
+ int this_color = colors_at(h, w);
+ *out_ptr_(i_chan, r * img_rows + i_r, c * img_cols + i_c) =
+ img_at(this_glyph, this_color, i_chan, i_r, i_c);
+ }
+ }
+ }
+ }
+ }
+ }
+ char_ptr += s_char_frame;
+ color_ptr += s_color_frame;
+ cur_ptr += s_cursor_frame;
+ out_ptr += s_out_frame;
+ }
+ }
+}
+void render_crop(py::object tty_chars, py::object tty_colors,
+ py::object tty_cursor, py::object images, py::object out_array,
+ int crop_size) {
+
+ check_array_types<uint8_t>(tty_chars, 2);
+ check_array_types<int8_t>(tty_colors, 2);
+ check_array_types<uint8_t>(tty_cursor, 1);
+ check_array_types<uint8_t>(images, 5);
+ check_array_types<uint8_t>(out_array, 3);
+ check_array_shapes(tty_chars, tty_colors, tty_cursor, images, out_array,
+ crop_size);
+
+ tile_crop(tty_chars, tty_colors, tty_cursor, images, out_array, crop_size);
+}
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(nethack_render_utils, m) {
+ m.doc() = R"pbdoc(
+ A module to turn glyphs into the screen in pixels
+ -----------------------
+ )pbdoc";
+
+ m.def("render_crop", &render_crop, py::arg("tty_chars"),
+ py::arg("tty_colors"), py::arg("tty_cursor"), py::arg("images"),
+ py::arg("out_array"), py::arg("crop_size") = 12);
+
+#ifdef VERSION_INFO
+ m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
+#else
+ m.attr("__version__") = "dev";
+#endif
+}
diff --git a/sf_examples/nethack/nethack_render_utils/test.py b/sf_examples/nethack/nethack_render_utils/test.py
new file mode 100644
index 000000000..e15b429af
--- /dev/null
+++ b/sf_examples/nethack/nethack_render_utils/test.py
@@ -0,0 +1,137 @@
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+import gym
+import nethack_render_utils as m
+import nle # NOQA: F401
+import numpy as np
+import tqdm
+from PIL import Image as im
+
+sys.path.append("/private/home/ehambro/fair/workspaces/wrapper-hackrl/hackrl")
+import wrappers # NOQA: E402
+
+
+def create_env():
+ return wrappers.RenderCharImagesWithNumpyWrapper(gym.make("NetHackChallenge-v0"), blstats_cursor=False)
+
+
+def load_obs():
+ e = create_env()
+ e.reset()
+ e.step(0)
+ obs = e.step(1)[0]
+ obs = e.step(5)[0]
+
+ images = e.char_array.copy()
+
+ return (
+ obs["tty_chars"].copy(),
+ obs["tty_colors"].copy(),
+ obs["tty_cursor"].copy(),
+ images,
+ obs["screen_image"].copy(),
+ )
+
+
+def test_main():
+ assert m.__version__ == "0.0.1"
+ assert m.add(1, 2) == 3
+ assert m.subtract(1, 2) == -1
+
+
+np.set_printoptions(threshold=sys.maxsize)
+
+
+def test(_):
+ obs = [np.ascontiguousarray(x) for x in load_obs()]
+ chars, colors, cursor, images, screen_image = obs
+ out = np.zeros_like(screen_image, order="C")
+ out = np.zeros((3, 72, 72), order="C", dtype=np.uint8)
+
+ m.render_crop(chars, colors, cursor, images, out)
+
+ if not np.all(out == screen_image):
+ scr_im = im.fromarray(np.transpose(screen_image, (1, 2, 0)))
+ out_im = im.fromarray(np.transpose(out, (1, 2, 0)))
+
+ # saving the final output
+ # as a PNG file
+ out_im.save("out_im.png")
+ scr_im.save("scr_im.png")
+ print(cursor[1] - 6, cursor[1] + 6)
+ print(
+ chars[
+ max(cursor[0] - 6, 0) : cursor[0] + 6,
+ max(cursor[1] - 6, 0) : cursor[1] + 6,
+ ]
+ )
+
+ np.testing.assert_array_equal(out, screen_image)
+
+
+if __name__ == "__main__":
+ with ThreadPoolExecutor(max_workers=10) as tp:
+
+ def fn(_):
+ obs = [np.ascontiguousarray(x) for x in load_obs()]
+ chars, colors, cursor, images, screen_image = obs
+
+ out = np.zeros_like(screen_image, order="C")
+ m.render_crop(chars, colors, cursor, images, out)
+ np.testing.assert_array_equal(screen_image, out)
+
+ def fn_batched(_):
+ obs = [np.ascontiguousarray(x) for x in load_obs()]
+ chars, colors, cursor, images, screen_image = obs
+ obs = [np.ascontiguousarray(np.stack([x] * 10)) for x in (chars, colors, cursor, screen_image)]
+ obs = [np.ascontiguousarray(np.stack([x] * 20)) for x in obs]
+ (chars, colors, cursor, screen_image) = obs
+
+ out = np.zeros_like(screen_image, order="C")
+ m.render_crop(chars, colors, cursor, images, out)
+ np.testing.assert_array_equal(screen_image, out)
+
+ retries = 100
+ batch_size = (100, 100)
+ obs = []
+ for _ in range(retries):
+ this_obs = [np.ascontiguousarray(x) for x in load_obs()]
+ chars, colors, cursor, images, screen_image = this_obs
+ z = [np.ascontiguousarray(np.stack([x] * batch_size[0])) for x in (chars, colors, cursor, screen_image)]
+ z = [np.ascontiguousarray(np.stack([x] * batch_size[1])) for x in z]
+ (chars, colors, cursor, screen_image) = z
+ this_obs = chars, colors, cursor, images, screen_image
+ out = np.zeros_like(screen_image, order="C")
+
+ obs.append((chars, colors, cursor, images, out))
+
+ print("Testing")
+ list(map(fn, tqdm.tqdm(range(200))))
+
+ print("Testing Batched")
+ list(map(fn_batched, tqdm.tqdm(range(200))))
+
+ print("Profile Single Thread")
+ start = time.time()
+ for o in obs:
+ chars, colors, cursor, images, out = o
+ m.render_crop(chars, colors, cursor, images, out)
+ t = time.time() - start
+ print("Time:", t)
+ print("SPS:", retries * batch_size[0] * batch_size[1] / t)
+
+ print("Profile Batch")
+ start = time.time()
+
+ def _parallel(o, i):
+ chars, colors, cursor, images, out = o
+ m.render_crop(chars[i], colors[i], cursor[i], images, out[i])
+
+ for o in obs:
+ list(tp.map(_parallel, [(o, i) for i in range(batch_size[0])]))
+
+ t = time.time() - start
+ print("Time:", t)
+ print("SPS:", retries * batch_size[0] * batch_size[1] / t)
diff --git a/sf_examples/nethack/scripts/sample_env.py b/sf_examples/nethack/scripts/sample_env.py
new file mode 100644
index 000000000..5b1b260f0
--- /dev/null
+++ b/sf_examples/nethack/scripts/sample_env.py
@@ -0,0 +1,34 @@
+import sys
+
+from sample_factory.algo.utils.rl_utils import make_dones
+from sample_factory.envs.create_env import create_env
+from sample_factory.utils.utils import log
+from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
+
+
+def main():
+ register_nethack_components()
+ cfg = parse_nethack_args(evaluation=True)
+
+ render_mode = "human"
+ if cfg.save_video:
+ render_mode = "rgb_array"
+ elif cfg.no_render:
+ render_mode = None
+
+ env = create_env(cfg.env, cfg=cfg, render_mode=render_mode)
+
+ env.seed(0)
+ env.action_space.seed(0)
+
+ for i in range(10):
+ env.reset()
+ done = False
+ while not done:
+ obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
+ done = make_dones(terminated, truncated)
+ log.info("Done!")
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/sf_examples/nethack/train_nethack.py b/sf_examples/nethack/train_nethack.py
new file mode 100644
index 000000000..133869bce
--- /dev/null
+++ b/sf_examples/nethack/train_nethack.py
@@ -0,0 +1,58 @@
+import sys
+
+from sample_factory.algo.utils.context import global_model_factory
+from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
+from sample_factory.envs.env_utils import register_env
+from sample_factory.model.encoder import Encoder
+from sample_factory.train import run_rl
+from sample_factory.utils.typing import Config, ObsSpace
+from sf_examples.nethack.models import MODELS_LOOKUP
+from sf_examples.nethack.nethack_env import NETHACK_ENVS, make_nethack_env
+from sf_examples.nethack.nethack_params import (
+ add_extra_params_general,
+ add_extra_params_model,
+ add_extra_params_nethack_env,
+ nethack_override_defaults,
+)
+
+
+def register_nethack_envs():
+ for env_name in NETHACK_ENVS.keys():
+ register_env(env_name, make_nethack_env)
+
+
+def make_nethack_encoder(cfg: Config, obs_space: ObsSpace) -> Encoder:
+ """Factory function as required by the API."""
+ try:
+ model_cls = MODELS_LOOKUP[cfg.model]
+ except KeyError:
+ raise NotImplementedError("model=%s" % cfg.model) from None
+
+ return model_cls(cfg, obs_space)
+
+
+def register_nethack_components():
+ register_nethack_envs()
+ global_model_factory().register_encoder_factory(make_nethack_encoder)
+
+
+def parse_nethack_args(argv=None, evaluation=False):
+ parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation)
+ add_extra_params_nethack_env(parser)
+ add_extra_params_model(parser)
+ add_extra_params_general(parser)
+ nethack_override_defaults(partial_cfg.env, parser)
+ final_cfg = parse_full_cfg(parser, argv)
+ return final_cfg
+
+
+def main(): # pragma: no cover
+ """Script entry point."""
+ register_nethack_components()
+ cfg = parse_nethack_args()
+ status = run_rl(cfg)
+ return status
+
+
+if __name__ == "__main__": # pragma: no cover
+ sys.exit(main())
diff --git a/sf_examples/nethack/utils/task_rewards.py b/sf_examples/nethack/utils/task_rewards.py
new file mode 100644
index 000000000..99daa8fc4
--- /dev/null
+++ b/sf_examples/nethack/utils/task_rewards.py
@@ -0,0 +1,168 @@
+import re
+
+import numpy as np
+from nle import nethack
+
+
+class Score:
+ def __init__(self):
+ self.score = 0
+ # convert name to snake_case
+ # https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
+ self.name = re.sub("(?!^)([A-Z]+)", r"_\1", self.__class__.__name__).lower()
+
+ def reset_score(self):
+ self.score = 0
+
+
+class GoldScore(Score):
+ def reward(self, env, last_observation, observation, end_status):
+ old_blstats = last_observation[env._blstats_index]
+ blstats = observation[env._blstats_index]
+
+ old_gold = old_blstats[nethack.NLE_BL_GOLD]
+ gold = blstats[nethack.NLE_BL_GOLD]
+
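+ # absolute change: both gaining and losing/spending gold increases the score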
+ reward = np.abs(gold - old_gold)
+ self.score += reward
+
+ return reward
+
+
+class EatingScore(Score):
+ def reward(self, env, last_observation, observation, end_status):
+ old_internal = last_observation[env._internal_index]
+ internal = observation[env._internal_index]
+
+ reward = max(0, internal[7] - old_internal[7])
+ self.score += reward
+
+ return reward
+
+
+class ScoutScore(Score):
+ def __init__(self):
+ super().__init__()
+ self.dungeon_explored = {}
+
+ def reward(self, env, last_observation, observation, end_status):
+ glyphs = observation[env._glyph_index]
+ blstats = observation[env._blstats_index]
+
+ dungeon_num = blstats[nethack.NLE_BL_DNUM]
+ dungeon_level = blstats[nethack.NLE_BL_DLEVEL]
+
+ key = (dungeon_num, dungeon_level)
+ explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF)
+ explored_old = 0
+ if key in self.dungeon_explored:
+ explored_old = self.dungeon_explored[key]
+ reward = explored - explored_old
+ self.dungeon_explored[key] = explored
+ self.score += reward
+
+ return reward
+
+ def reset_score(self):
+ super().reset_score()
+ self.dungeon_explored = {}
+
+
+class StaircaseScore(Score):
+ """
+ This task requires the agent to get on top of a staircase down (>).
+ The reward function is :math:`I`, where :math:`I` is 1 if the
+ task is successful, and 0 otherwise.
+ """
+
+ def reward(self, env, last_observation, observation, end_status):
+ internal = observation[env._internal_index]
+ stairs_down = internal[4]
+
+ reward = 1 if stairs_down else 0
+ self.score += reward
+
+ return reward
+
+
+class StaircasePetScore(Score):
+ """
+ This task requires the agent to get on top of a staircase down (>), while
+ having their pet next to it. See `NetHackStaircase` for the reward function.
+ """
+
+ def reward(self, env, last_observation, observation, end_status):
+ internal = observation[env._internal_index]
+ stairs_down = internal[4]
+
+ reward = 0
+ if stairs_down:
+ glyphs = observation[env._glyph_index]
+ blstats = observation[env._blstats_index]
+ x, y = blstats[:2]
+
+ neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2]
+ if np.any(nethack.glyph_is_pet(neighbors)):
+ reward = 1
+
+ self.score += reward
+
+ return reward
+
+
+class SokobanfillpitScore(Score):
+ """
+ This task requires the agent to push boulders into the pits and holes in Sokoban.
+ Each boulder successfully pushed into a pit or hole adds 1 to the total reward.
+ """
+
+ def reward(self, env, last_observation, observation, end_status):
+ # the score counts how many pits we fill
+ char_array = [chr(i) for i in observation[env._message_index]]
+ message = "".join(char_array)
+
+ if message.startswith("The boulder fills a pit.") or message.startswith(
+ "The boulder falls into and plugs a whole in the floor!"
+ ):
+ reward = 1
+ else:
+ reward = 0
+ self.score += reward
+
+ return reward
+
+
+class SokobansolvedlevelsScore(Score):
+ def __init__(self):
+ super().__init__()
+ self.sokoban_levels = {}
+
+ def reward(self, env, last_observation, observation, end_status):
+ glyphs = observation[env._glyph_index]
+ blstats = observation[env._blstats_index]
+
+ dungeon_num = blstats[nethack.NLE_BL_DNUM]
+ dungeon_level = blstats[nethack.NLE_BL_DLEVEL]
+
+ # when we know that this is sokoban
+ if dungeon_num == 4:
+ # TODO: maybe we should count "solving" sokoban level when we reach the next level of the sokoban?
+ # checking if all pits are solved can be buggy if the glyphs have different values on other levels
+
+ # count the number of pits, glyphs SS.S_pit
+ pits = np.isin(glyphs, [2411]).sum()
+ key = (dungeon_num, dungeon_level)
+ self.sokoban_levels[key] = pits
+
+ def reset_score(self):
+ super().reset_score()
+ self.sokoban_levels = {}
+
+ @property
+ def score(self):
+ score = 0
+ for pits in self.sokoban_levels.values():
+ # when all pits are filled we assume that sokoban level is solved
+ if pits == 0:
+ score += 1
+ return score
+
+ @score.setter
+ def score(self, value):
+ # Score.__init__ and Score.reset_score assign to `score` directly;
+ # ignore those writes here, since this score is derived from sokoban_levels
+ pass
diff --git a/sf_examples/nethack/utils/wrappers/__init__.py b/sf_examples/nethack/utils/wrappers/__init__.py
new file mode 100644
index 000000000..6867f54dd
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/__init__.py
@@ -0,0 +1,13 @@
+from sf_examples.nethack.utils.wrappers.blstats_info import BlstatsInfoWrapper
+from sf_examples.nethack.utils.wrappers.prev_actions import PrevActionsWrapper
+from sf_examples.nethack.utils.wrappers.screen_image import RenderCharImagesWithNumpyWrapperV2
+from sf_examples.nethack.utils.wrappers.seed_action_space import SeedActionSpaceWrapper
+from sf_examples.nethack.utils.wrappers.task_rewards import TaskRewardsInfoWrapper
+
+__all__ = [
+ "RenderCharImagesWithNumpyWrapperV2",
+ "PrevActionsWrapper",
+ "TaskRewardsInfoWrapper",
+ "BlstatsInfoWrapper",
+ "SeedActionSpaceWrapper",
+]
diff --git a/sf_examples/nethack/utils/wrappers/blstats_info.py b/sf_examples/nethack/utils/wrappers/blstats_info.py
new file mode 100644
index 000000000..faef9fdd6
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/blstats_info.py
@@ -0,0 +1,36 @@
+from collections import namedtuple
+
+import gym
+
+BLStats = namedtuple(
+ "BLStats",
+ "x y strength_percentage strength dexterity constitution intelligence wisdom charisma score hitpoints max_hitpoints depth gold energy max_energy armor_class monster_level experience_level experience_points time hunger_state carrying_capacity dungeon_number level_number prop_mask align_bits",
+)
+
+
+class BlstatsInfoWrapper(gym.Wrapper):
+ def step(self, action):
+ # because we will see done=True at the first timestep of the new episode
+ # to properly calculate blstats at the end of the episode we need to keep the last_observation around
+ last_observation = tuple(a.copy() for a in self.env.unwrapped.last_observation)
+ obs, reward, done, info = self.env.step(action)
+
+ if done:
+ info["episode_extra_stats"] = self.add_more_stats(info, last_observation)
+
+ return obs, reward, done, info
+
+ def add_more_stats(self, info, last_observation):
+ extra_stats = info.get("episode_extra_stats", {})
+ blstats = BLStats(*last_observation[self.env.unwrapped._blstats_index])
+ new_extra_stats = {
+ "score": blstats.score,
+ "turns": blstats.time,
+ "dlvl": blstats.depth,
+ "max_hitpoints": blstats.max_hitpoints,
+ "max_energy": blstats.max_energy,
+ "armor_class": blstats.armor_class,
+ "experience_level": blstats.experience_level,
+ "experience_points": blstats.experience_points,
+ }
+ return {**extra_stats, **new_extra_stats}
diff --git a/sf_examples/nethack/utils/wrappers/prev_actions.py b/sf_examples/nethack/utils/wrappers/prev_actions.py
new file mode 100644
index 000000000..0afb67793
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/prev_actions.py
@@ -0,0 +1,24 @@
+import gym
+import numpy as np
+
+
+class PrevActionsWrapper(gym.Wrapper):
+ def __init__(self, env):
+ super().__init__(env)
+ self.prev_action = 0
+
+ obs_spaces = {"prev_actions": self.env.action_space}
+ obs_spaces.update([(k, self.env.observation_space[k]) for k in self.env.observation_space])
+ self.observation_space = gym.spaces.Dict(obs_spaces)
+
+ def reset(self, **kwargs):
+ self.prev_action = 0
+ obs = self.env.reset(**kwargs)
+ obs["prev_actions"] = np.array([self.prev_action])
+ return obs
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ self.prev_action = action
+ obs["prev_actions"] = np.array([self.prev_action])
+ return obs, reward, done, info
diff --git a/sf_examples/nethack/utils/wrappers/screen_image.py b/sf_examples/nethack/utils/wrappers/screen_image.py
new file mode 100644
index 000000000..fae99f25d
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/screen_image.py
@@ -0,0 +1,310 @@
+"""Taken & adapted from Chaos Dwarf in Nethack Challenge Starter Kit:
+https://github.com/Miffyli/nle-sample-factory-baseline
+MIT License
+Copyright (c) 2021 Anssi
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import os
+from typing import Any, SupportsFloat
+
+import cv2
+import gym
+import nethack_render_utils
+import numpy as np
+from nle import nethack
+from numba import njit
+from PIL import Image, ImageDraw, ImageFont
+
+SMALL_FONT_PATH = os.path.abspath("sf_examples/nethack/nethack_render_utils/Hack-Regular.ttf")
+
+
+# Mapping of 0-15 colors used.
+# Taken from bottom image here. It seems about right
+# https://i.stack.imgur.com/UQVe5.png
+COLORS = [
+ "#000000",
+ "#800000",
+ "#008000",
+ "#808000",
+ "#000080",
+ "#800080",
+ "#008080",
+ "#808080", # - flipped these ones around
+ "#C0C0C0", # | the gray-out dull stuff
+ "#FF0000",
+ "#00FF00",
+ "#FFFF00",
+ "#0000FF",
+ "#FF00FF",
+ "#00FFFF",
+ "#FFFFFF",
+]
+
+
+@njit
+def _tile_characters_to_image(
+ out_image,
+ chars,
+ colors,
+ output_height_chars,
+ output_width_chars,
+ char_array,
+ offset_h,
+ offset_w,
+):
+ """
+ Build an image using cached images of characters in char_array to out_image
+ """
+ char_height = char_array.shape[3]
+ char_width = char_array.shape[4]
+ for h in range(output_height_chars):
+ h_char = h + offset_h
+ # Stuff outside boundaries is not visible, so
+ # just leave it black
+ if h_char < 0 or h_char >= chars.shape[0]:
+ continue
+ for w in range(output_width_chars):
+ w_char = w + offset_w
+ if w_char < 0 or w_char >= chars.shape[1]:
+ continue
+ char = chars[h_char, w_char]
+ color = colors[h_char, w_char]
+ h_pixel = h * char_height
+ w_pixel = w * char_width
+ out_image[:, h_pixel : h_pixel + char_height, w_pixel : w_pixel + char_width] = char_array[char, color]
+
+
+def _initialize_char_array(font_size, rescale_font_size):
+ """Draw all characters in PIL and cache them in numpy arrays
+ if rescale_font_size is given, assume it is (width, height)
+ Returns a np array of (num_chars, num_colors, char_height, char_width, 3)
+ """
+ font = ImageFont.truetype(SMALL_FONT_PATH, font_size)
+ dummy_text = "".join([(chr(i) if chr(i).isprintable() else " ") for i in range(256)])
+ bboxes = np.array([font.getbbox(char) for char in dummy_text])
+ image_width = bboxes[:, 2].max()
+ image_height = bboxes[:, 3].max()
+
+ char_width = rescale_font_size[0]
+ char_height = rescale_font_size[1]
+ char_array = np.zeros((256, 16, char_height, char_width, 3), dtype=np.uint8)
+
+ for color_index in range(16):
+ for char_index in range(256):
+ char = dummy_text[char_index]
+
+ image = Image.new("RGB", (image_width, image_height))
+ image_draw = ImageDraw.Draw(image)
+ image_draw.rectangle((0, 0, image_width, image_height), fill=(0, 0, 0))
+
+ _, _, width, height = font.getbbox(char)
+ x = (image_width - width) // 2
+ y = (image_height - height) // 2
+ image_draw.text((x, y), char, font=font, fill=COLORS[color_index])
+
+ arr = np.array(image).copy()
+ if rescale_font_size:
+ arr = cv2.resize(arr, rescale_font_size, interpolation=cv2.INTER_AREA)
+ char_array[char_index, color_index] = arr
+
+ return char_array
+
+
+class RenderCharImagesWithNumpyWrapper(gym.Wrapper):
+ """
+ Render characters as images, using PIL to render characters like we humans see on screen
+ but then some caching and numpy stuff to speed up things.
+ To speed things up, crop image around the player.
+ """
+
+ def __init__(
+ self,
+ env,
+ font_size=9,
+ crop_size=12,
+ rescale_font_size=(6, 6),
+ blstats_cursor=False,
+ ):
+ super().__init__(env)
+ self.char_array = _initialize_char_array(font_size, rescale_font_size)
+ self.char_height = self.char_array.shape[2]
+ self.char_width = self.char_array.shape[3]
+ # Transpose for CHW
+ self.char_array = self.char_array.transpose(0, 1, 4, 2, 3)
+
+ self.crop_size = crop_size
+ self.blstats_cursor = blstats_cursor
+
+ self.half_crop_size = crop_size // 2
+ self.output_height_chars = crop_size
+ self.output_width_chars = crop_size
+ self.chw_image_shape = (
+ 3,
+ self.output_height_chars * self.char_height,
+ self.output_width_chars * self.char_width,
+ )
+
+ obs_spaces = {"screen_image": gym.spaces.Box(low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8)}
+ obs_spaces.update(
+ [
+ (k, self.env.observation_space[k])
+ for k in self.env.observation_space
+ if k not in ["tty_chars", "tty_colors"]
+ ]
+ )
+ self.observation_space = gym.spaces.Dict(obs_spaces)
+
+ def _render_text_to_image(self, obs):
+ chars = obs["tty_chars"]
+ colors = obs["tty_colors"]
+ offset_w = 0
+ offset_h = 0
+ if self.crop_size:
+ # Center around player
+ if self.blstats_cursor:
+ center_x, center_y = obs["blstats"][:2]
+ else:
+ center_y, center_x = obs["tty_cursor"]
+ offset_h = center_y - self.half_crop_size
+ offset_w = center_x - self.half_crop_size
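+            # offsets can be negative near map edges; _tile_characters_to_image leaves those cells black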
+
+ out_image = np.zeros(self.chw_image_shape, dtype=np.uint8)
+
+ _tile_characters_to_image(
+ out_image=out_image,
+ chars=chars,
+ colors=colors,
+ output_height_chars=self.output_height_chars,
+ output_width_chars=self.output_width_chars,
+ char_array=self.char_array,
+ offset_h=offset_h,
+ offset_w=offset_w,
+ )
+
+ obs["screen_image"] = out_image
+ return obs
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ obs = self._render_text_to_image(obs)
+ return obs, reward, done, info
+
+ def reset(self, **kwargs):
+ obs = self.env.reset(**kwargs)
+ obs = self._render_text_to_image(obs)
+ return obs
+
+
+class RenderCharImagesWithNumpyWrapperV2(gym.Wrapper):
+ """
+ Same as V1, but simpler and faster.
+ """
+
+ def __init__(
+ self,
+ env,
+ font_size=9,
+ crop_size=12,
+ rescale_font_size=(6, 6),
+ render_font_size=(6, 11),
+ ):
+ super().__init__(env)
+ self.char_array = _initialize_char_array(font_size, rescale_font_size)
+ self.char_height = self.char_array.shape[2]
+ self.char_width = self.char_array.shape[3]
+ # Transpose for CHW
+ self.char_array = self.char_array.transpose(0, 1, 4, 2, 3)
+ self.char_array = np.ascontiguousarray(self.char_array)
+ self.crop_size = crop_size
+
+ crop_rows = crop_size or nethack.nethack.TERMINAL_SHAPE[0]
+ crop_cols = crop_size or nethack.nethack.TERMINAL_SHAPE[1]
+
+ self.chw_image_shape = (
+ 3,
+ crop_rows * self.char_height,
+ crop_cols * self.char_width,
+ )
+
+ obs_spaces = {"screen_image": gym.spaces.Box(low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8)}
+ obs_spaces.update(
+ [
+ (k, self.env.observation_space[k])
+ for k in self.env.observation_space
+ # if k not in ["tty_chars", "tty_colors"]
+ ]
+ )
+ self.observation_space = gym.spaces.Dict(obs_spaces)
+
+ self.render_char_array = _initialize_char_array(font_size, render_font_size)
+ self.render_char_array = self.render_char_array.transpose(0, 1, 4, 2, 3)
+ self.render_char_array = np.ascontiguousarray(self.render_char_array)
+
+ def _populate_obs(self, obs):
+ screen = np.zeros(self.chw_image_shape, order="C", dtype=np.uint8)
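+        # render_crop is provided by the bundled nethack_render_utils extension; it crops
+        # around the tty cursor and tiles the cached glyphs into `screen` in one call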
+ nethack_render_utils.render_crop(
+ obs["tty_chars"],
+ obs["tty_colors"],
+ obs["tty_cursor"],
+ self.char_array,
+ screen,
+ crop_size=self.crop_size,
+ )
+ obs["screen_image"] = screen
+
+    def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, dict[str, Any]]:
+ obs, reward, done, info = self.env.step(action)
+ self._populate_obs(obs)
+ return obs, reward, done, info
+
+ def reset(self, **kwargs):
+ obs = self.env.reset(**kwargs)
+ self._populate_obs(obs)
+ return obs
+
+ def render(self, mode="human"):
+ if mode == "rgb_array":
+ if not self.unwrapped.last_observation:
+ return
+
+            # TODO: we render the full terminal here (no crop); we could additionally show the cropped view the model sees
+ obs = self.unwrapped.last_observation
+ tty_chars = obs[self.unwrapped._observation_keys.index("tty_chars")]
+ tty_colors = obs[self.unwrapped._observation_keys.index("tty_colors")]
+
+ chw_image_shape = (
+ 3,
+ nethack.nethack.TERMINAL_SHAPE[0] * self.render_char_array.shape[3],
+ nethack.nethack.TERMINAL_SHAPE[1] * self.render_char_array.shape[4],
+ )
+ out_image = np.zeros(chw_image_shape, dtype=np.uint8)
+
+ _tile_characters_to_image(
+ out_image=out_image,
+ chars=tty_chars,
+ colors=tty_colors,
+ output_height_chars=nethack.nethack.TERMINAL_SHAPE[0],
+ output_width_chars=nethack.nethack.TERMINAL_SHAPE[1],
+ char_array=self.render_char_array,
+ offset_h=0,
+ offset_w=0,
+ )
+
+ return out_image
+ else:
+ return self.env.render()
diff --git a/sf_examples/nethack/utils/wrappers/seed_action_space.py b/sf_examples/nethack/utils/wrappers/seed_action_space.py
new file mode 100644
index 000000000..2ebf363dc
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/seed_action_space.py
@@ -0,0 +1,14 @@
+from typing import Any
+
+import gymnasium as gym
+
+
+class SeedActionSpaceWrapper(gym.Wrapper):
+ """
+ To have reproducible decorrelate experience we need to seed action space
+ """
+
+ def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
+ obs, info = self.env.reset(seed=seed, options=options)
+ self.action_space.seed(seed=seed)
+ return obs, info
diff --git a/sf_examples/nethack/utils/wrappers/task_rewards.py b/sf_examples/nethack/utils/wrappers/task_rewards.py
new file mode 100644
index 000000000..9b58722f3
--- /dev/null
+++ b/sf_examples/nethack/utils/wrappers/task_rewards.py
@@ -0,0 +1,55 @@
+import gym
+
+from sf_examples.nethack.utils.task_rewards import (
+ EatingScore,
+ GoldScore,
+ ScoutScore,
+ SokobanfillpitScore,
+ SokobansolvedlevelsScore,
+ StaircasePetScore,
+ StaircaseScore,
+)
+
+
+class TaskRewardsInfoWrapper(gym.Wrapper):
+ def __init__(self, env: gym.Env):
+ super().__init__(env)
+
+ self.tasks = [
+ EatingScore(),
+ GoldScore(),
+ ScoutScore(),
+ SokobanfillpitScore(),
+        # SokobansolvedlevelsScore(),  # TODO: possibly buggy, disabled for now
+ StaircasePetScore(),
+ StaircaseScore(),
+ ]
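+        # each task's score is reported under info["episode_extra_stats"] at episode end (see step/add_more_stats)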
+
+ def reset(self, **kwargs):
+ obs = self.env.reset(**kwargs)
+
+ for task in self.tasks:
+ task.reset_score()
+
+ return obs
+
+ def step(self, action):
+        # copy each array into a tuple to avoid aliasing (otherwise `last_observation` would point to the same buffers as `observation`)
+ last_observation = tuple(a.copy() for a in self.env.unwrapped.last_observation)
+ obs, reward, done, info = self.env.step(action)
+ observation = tuple(a.copy() for a in self.env.unwrapped.last_observation)
+ end_status = info["end_status"]
+
+        # accumulate each task's reward every step; the totals are logged once the done signal appears
+        for task in self.tasks:
+            task.reward(self.env.unwrapped, last_observation, observation, end_status)
+
+        if done:
+            info["episode_extra_stats"] = self.add_more_stats(info)
+
+        return obs, reward, done, info
+
+ def add_more_stats(self, info):
+ extra_stats = info.get("episode_extra_stats", {})
+ new_extra_stats = {task.name: task.score for task in self.tasks}
+ return {**extra_stats, **new_extra_stats}