NetHack: fix rendering, handle timeouts #294

Merged · 3 commits · Mar 5, 2024
1 change: 0 additions & 1 deletion docs/09-environment-integrations/nethack.md
@@ -19,7 +19,6 @@ git clone https://github.com/facebookresearch/nle.git nle && cd nle \
&& sed '/self\.nethack\.get_current_seeds = f/d' nle/env/tasks.py -i \
&& sed '/def seed(self, core=None, disp=None, reseed=True):/d' nle/env/tasks.py -i \
&& sed '/raise RuntimeError("NetHackChallenge doesn.t allow seed changes")/d' nle/env/tasks.py -i \
-&& sed -i '/def render(self, mode="human"):/a\ if not self.last_observation:\n return' nle/env/base.py \
&& python setup.py install && cd ..

# install sample factory with nethack extras
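For reference, the deleted `sed` line used to patch `render` in `nle/env/base.py` so that calling it before the first observation existed became a no-op. Roughly, the patched method began like the sketch below (a reconstruction inferred from the `sed` expression, not copied from upstream NLE). The guard is no longer needed because the new `GymV21CompatibilityV0` wrapper only renders after `reset()` has produced an observation.

```python
# Reconstruction of the effect of the removed sed patch (illustrative only)
def render(self, mode="human"):
    if not self.last_observation:  # guard injected by the old sed line
        return
    ...  # original NLE rendering logic continues here
```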
1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ include = '\.pyi?$'
py_version = 38
line_length = 120
profile = 'black'
+known_third_party = ["nle"]

[tool.pytest.ini_options]
addopts = "-s"
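The new `known_third_party` entry tells isort to keep `nle` in the third-party import section even when the package is not importable where isort runs (for example, a lint-only CI job). A hypothetical module would then be grouped like this (an illustrative sketch, not a file from this repo):

```python
import os  # stdlib section

import torch  # third-party section
from nle import nethack  # forced into third-party by known_third_party

from sample_factory.utils.utils import log  # first-party section
```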
1 change: 1 addition & 0 deletions sample_factory/algo/utils/model_sharing.py
@@ -1,6 +1,7 @@
"""
Utilities for sharing model parameters between components.
"""
+
import sys
from typing import Optional

1 change: 1 addition & 0 deletions sample_factory/algo/utils/running_mean_std.py
@@ -3,6 +3,7 @@
All credit goes to https://github.com/Denys88/rl_games (only slightly changed here, mostly with in-place operations)
Thanks a lot, great module!
"""
+
from typing import Dict, Final, List, Optional, Union

import gymnasium as gym
1 change: 1 addition & 0 deletions sample_factory/envs/env_wrappers.py
@@ -2,6 +2,7 @@
Gym env wrappers that make the environment suitable for the RL algorithms.

"""
+
import json
import os
from os.path import join
1 change: 0 additions & 1 deletion sample_factory/launcher/run_slurm.py
@@ -4,7 +4,6 @@

"""

-
import os
import time
from os.path import join
1 change: 0 additions & 1 deletion sample_factory/pbt/population_based_training.py
@@ -1,6 +1,5 @@
"""Population-Based Training implementation, inspired by https://arxiv.org/abs/1807.01281."""

-
import copy
import json
import math
1 change: 1 addition & 0 deletions sample_factory/utils/normalize.py
@@ -10,6 +10,7 @@
If no data normalization is needed we just keep the original data.
Otherwise, we create a copy of data and do all of the operations in-place.
"""
+
from typing import Dict

import torch
1 change: 0 additions & 1 deletion sf_examples/brax/brax_render.py
@@ -6,7 +6,6 @@
render_mode=human (although still slow and in low resolution).
"""

-
from typing import Dict, List, Tuple

import brax
1 change: 1 addition & 0 deletions sf_examples/brax/train_brax.py
@@ -1,6 +1,7 @@
"""
Brax env integration.
"""
+
import sys
from typing import Dict, List, Optional, Tuple, Union

1 change: 1 addition & 0 deletions sf_examples/nethack/models/chaotic_dwarf.py
@@ -23,6 +23,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
+
import torch
from nle import nethack
from torch import nn
15 changes: 15 additions & 0 deletions sf_examples/nethack/nethack_env.py
@@ -14,6 +14,8 @@
from sample_factory.algo.utils.gymnasium_utils import patch_non_gymnasium_env
from sf_examples.nethack.utils.wrappers import (
    BlstatsInfoWrapper,
+    GymV21CompatibilityV0,
+    NLETimeLimit,
    PrevActionsWrapper,
    RenderCharImagesWithNumpyWrapperV2,
    SeedActionSpaceWrapper,
@@ -95,6 +97,19 @@ def make_nethack_env(env_name, cfg, env_config, render_mode: Optional[str] = None):
    env = BlstatsInfoWrapper(env)
    env = TaskRewardsInfoWrapper(env)

+    # add TimeLimit.truncated to info
+    env = NLETimeLimit(env)
+
+    # convert the gym env to a gymnasium one; this works around issues with rendering NLE during reset
+    gymnasium_env = GymV21CompatibilityV0(env=env)
+
+    # preserve potential multi-agent env attributes
+    if hasattr(env, "num_agents"):
+        gymnasium_env.num_agents = env.num_agents
+    if hasattr(env, "is_multiagent"):
+        gymnasium_env.is_multiagent = env.is_multiagent
+    env = gymnasium_env

    env = patch_non_gymnasium_env(env)

    if render_mode:
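Note the ordering: `NLETimeLimit` has to sit beneath `GymV21CompatibilityV0`, because the compatibility wrapper's `step()` routes the legacy four-tuple through gymnasium's `convert_to_terminated_truncated_step_api`, which consumes the `TimeLimit.truncated` flag that `NLETimeLimit` writes. A minimal sketch of the resulting stack (hypothetical standalone usage; assumes `gym` and `nle` are installed):

```python
import gym  # legacy gym API used by NLE
import nle  # noqa: F401 -- registers the NetHack env ids with gym

from sf_examples.nethack.utils.wrappers import GymV21CompatibilityV0, NLETimeLimit

env = gym.make("NetHackChallenge-v0")  # old-style env: step() returns a 4-tuple
env = NLETimeLimit(env)                # writes info["TimeLimit.truncated"] on aborts
env = GymV21CompatibilityV0(env=env)   # new-style env: step() returns a 5-tuple

obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```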
4 changes: 4 additions & 0 deletions sf_examples/nethack/utils/wrappers/__init__.py
@@ -1,13 +1,17 @@
from sf_examples.nethack.utils.wrappers.blstats_info import BlstatsInfoWrapper
+from sf_examples.nethack.utils.wrappers.gym_compatibility import GymV21CompatibilityV0
from sf_examples.nethack.utils.wrappers.prev_actions import PrevActionsWrapper
from sf_examples.nethack.utils.wrappers.screen_image import RenderCharImagesWithNumpyWrapperV2
from sf_examples.nethack.utils.wrappers.seed_action_space import SeedActionSpaceWrapper
from sf_examples.nethack.utils.wrappers.task_rewards import TaskRewardsInfoWrapper
+from sf_examples.nethack.utils.wrappers.timelimit import NLETimeLimit

__all__ = [
    "RenderCharImagesWithNumpyWrapperV2",
    "PrevActionsWrapper",
    "TaskRewardsInfoWrapper",
    "BlstatsInfoWrapper",
    "SeedActionSpaceWrapper",
+    "NLETimeLimit",
+    "GymV21CompatibilityV0",
]
227 changes: 227 additions & 0 deletions sf_examples/nethack/utils/wrappers/gym_compatibility.py
@@ -0,0 +1,227 @@
"""Compatibility wrappers for OpenAI gym V22 and V26."""

# pyright: reportGeneralTypeIssues=false, reportPrivateImportUsage=false
from __future__ import annotations

import sys
from typing import Any

import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType
from gymnasium.error import MissingArgument
from gymnasium.logger import warn
from gymnasium.spaces import Box, Dict, Discrete, Graph, MultiBinary, MultiDiscrete, Sequence, Text, Tuple
from gymnasium.utils.step_api_compatibility import convert_to_terminated_truncated_step_api

if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable
else:
from typing_extensions import Protocol, runtime_checkable


try:
import gym
import gym.wrappers
except ImportError as e:
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None


@runtime_checkable
class LegacyV21Env(Protocol):
"""A protocol for OpenAI Gym v0.21 environment."""

observation_space: gym.Space
action_space: gym.Space

def reset(self) -> Any:
"""Reset the environment and return the initial observation."""
...

def step(self, action: Any) -> tuple[Any, float, bool, dict]:
"""Run one timestep of the environment's dynamics."""
...

def render(self, mode: str | None = "human") -> Any:
"""Render the environment."""
...

def close(self):
"""Close the environment."""
...

def seed(self, seed: int | None = None):
"""Set the seed for this env's random number generator(s)."""
...


class GymV21CompatibilityV0(gymnasium.Env[ObsType, ActType]):
r"""A wrapper which can transform an environment from the old API to the new API.

Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation.
New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
(Refer to docs for details on the API change)

Known limitations:
- Environments that use `self.np_random` might not work as expected.
"""

def __init__(
self,
env_id: str | None = None,
make_kwargs: dict | None = None,
env: gym.Env | None = None,
render_mode: str | None = None,
):
"""A wrapper which converts old-style envs to valid modern envs.

Some information may be lost in the conversion, so we recommend updating your environment.
"""
if GYM_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
)

if make_kwargs is None:
make_kwargs = {}

if env is not None:
gym_env = env
elif env_id is not None:
gym_env = gym.make(env_id, **make_kwargs)
else:
raise MissingArgument("Either env_id or env must be provided to create a legacy gym environment.")
self.observation_space = _convert_space(gym_env.observation_space)
self.action_space = _convert_space(gym_env.action_space)

gym_env = _strip_default_wrappers(gym_env)

self.metadata = getattr(gym_env, "metadata", {"render_modes": []})
self.render_mode = render_mode
self.reward_range = getattr(gym_env, "reward_range", None)
self.spec = getattr(gym_env, "spec", None)

self.gym_env: LegacyV21Env = gym_env

def __getattr__(self, item: str):
"""Gets an attribute that only exists in the base environments."""
return getattr(self.gym_env, item)

def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[ObsType, dict]:
"""Resets the environment.

Args:
seed: the seed to reset the environment with
options: the options to reset the environment with

Returns:
(observation, info)
"""
if seed is not None:
self.gym_env.seed(seed)

# Options are ignored - https://github.com/openai/gym/blob/c755d5c35a25ab118746e2ba885894ff66fb8c43/gym/core.py
if options is not None:
warn(f"Gym v21 environment do not accept options as a reset parameter, options={options}")

obs, info = self.gym_env.reset(), {}

if self.render_mode is not None:
self.render()

return obs, info

def step(self, action: ActType) -> tuple[Any, float, bool, bool, dict]:
"""Steps through the environment.

Args:
action: action to step through the environment with

Returns:
(observation, reward, terminated, truncated, info)
"""
obs, reward, done, info = self.gym_env.step(action)

if self.render_mode is not None:
self.render()

return convert_to_terminated_truncated_step_api((obs, reward, done, info))

def render(self) -> Any:
"""Renders the environment.

Returns:
The rendering of the environment, depending on the render mode
"""
return self.gym_env.render(mode=self.render_mode)

def close(self):
"""Closes the environment."""
self.gym_env.close()

def __str__(self):
"""Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.gym_env}>"

def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)


def _strip_default_wrappers(env: gym.Env) -> gym.Env:
"""Strips builtin wrappers from the environment.

Args:
env: the environment to strip builtin wrappers from

Returns:
The environment without builtin wrappers
"""
default_wrappers = ()
if hasattr(gym.wrappers, "render_collection"):
default_wrappers += (gym.wrappers.render_collection.RenderCollection,)
if hasattr(gym.wrappers, "human_rendering"):
default_wrappers += (gym.wrappers.human_rendering.HumanRendering,)
while isinstance(env, default_wrappers):
env = env.env
return env


def _convert_space(space: gym.Space) -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.

Args:
space: the space to convert

Returns:
The converted space
"""
if isinstance(space, gym.spaces.Discrete):
return Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype)
elif isinstance(space, gym.spaces.MultiDiscrete):
return MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.MultiBinary):
return MultiBinary(n=space.n)
elif isinstance(space, gym.spaces.Tuple):
return Tuple(spaces=tuple(map(_convert_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return Dict(spaces={k: _convert_space(v) for k, v in space.spaces.items()})
elif isinstance(space, gym.spaces.Sequence):
return Sequence(space=_convert_space(space.feature_space))
elif isinstance(space, gym.spaces.Graph):
return Graph(
node_space=_convert_space(space.node_space), # type: ignore
edge_space=_convert_space(space.edge_space), # type: ignore
)
elif isinstance(space, gym.spaces.Text):
return Text(
max_length=space.max_length,
min_length=space.min_length,
charset=space._char_str,
)
else:
raise NotImplementedError(f"Cannot convert space of type {space}. Please upgrade your code to gymnasium.")
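The bulk of the file is the recursive `_convert_space` translation between the two libraries' space types. A small sketch of what it does, with made-up NLE-like values (illustrative only; assumes `gym`, `gymnasium`, and `numpy` are installed, and `_convert_space` is in scope as above):

```python
import gym
import gymnasium
import numpy as np

legacy = gym.spaces.Dict(
    {
        "glyphs": gym.spaces.Box(low=0, high=5976, shape=(21, 79), dtype=np.int16),
        "message": gym.spaces.Box(low=0, high=255, shape=(256,), dtype=np.uint8),
    }
)
converted = _convert_space(legacy)  # gymnasium.spaces.Dict with the same nested structure
assert isinstance(converted, gymnasium.spaces.Dict)
```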
12 changes: 12 additions & 0 deletions sf_examples/nethack/utils/wrappers/timelimit.py
@@ -0,0 +1,12 @@
import gym
from nle.env.base import NLE


class NLETimeLimit(gym.Wrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, action):
        obs, reward, done, info = super().step(action)
        # NLE reports hitting its internal step limit as StepStatus.ABORTED;
        # surface that as the standard TimeLimit.truncated flag so the gymnasium
        # step-API conversion can mark the episode as truncated rather than terminated
        info["TimeLimit.truncated"] = info["end_status"] == NLE.StepStatus.ABORTED
        return obs, reward, done, info
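Downstream, `GymV21CompatibilityV0.step()` passes this flag to gymnasium's `convert_to_terminated_truncated_step_api`, which pops `TimeLimit.truncated` from `info` and uses it to split the legacy `done` into `terminated`/`truncated`. Roughly (a sketch of the conversion semantics with placeholder values, not the library source):

```python
from gymnasium.utils.step_api_compatibility import convert_to_terminated_truncated_step_api

obs = None  # placeholder observation
legacy_step = (obs, 0.0, True, {"TimeLimit.truncated": True})  # done only due to the step limit
obs, reward, terminated, truncated, info = convert_to_terminated_truncated_step_api(legacy_step)
assert terminated is False and truncated is True  # time-limit abort, not a true episode end
```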
1 change: 1 addition & 0 deletions sf_examples/train_custom_env_custom_model.py
@@ -6,6 +6,7 @@
python -m sf_examples.enjoy_custom_env_custom_model --algo=APPO --env=my_custom_env_v1 --experiment=example

"""
+
from __future__ import annotations

import sys
1 change: 1 addition & 0 deletions sf_examples/train_custom_multi_env.py
@@ -6,6 +6,7 @@
python -m sf_examples.enjoy_custom_multi_env --algo=APPO --env=my_custom_multi_env_v1 --experiment=example_multi

"""
+
from __future__ import annotations

import random
1 change: 1 addition & 0 deletions sf_examples/vizdoom/train_custom_vizdoom_env.py
@@ -8,6 +8,7 @@
python -m sf_examples.vizdoom.enjoy_custom_vizdoom_env --algo=APPO --env=doom_my_custom_env --experiment=doom_my_custom_env_example

"""
+
import argparse
import functools
import os