-
-
Notifications
You must be signed in to change notification settings - Fork 404
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix import * add gym_wrapper * update test * update gym version * minor refactor * Test commit to see that I do not nuke all of git yet again * Give warnings on DELTA buttons instead of errors * Remove env params. Change to dict obs * Fix things and tests * Fix for GRAY8 * Clean up code * Cleaning up * Update example * Add docs and stable-baselines3 example * Add support for automap buffer to ViZDoomEnv, add warning when format other than RGB24 or GRAY8, fix displaying of GRAY8 images by PyGame * Update render, force mode * Cleaning * Fix typo in VizdoomEnv.__collect_observations Co-authored-by: Anssi 'Miffyli' Kanervisto <[email protected]> Co-authored-by: Marek Wydmuch <[email protected]>
- Loading branch information
1 parent
986157a
commit 92df36c
Showing
13 changed files
with
532 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# OpenAI Gym wrappers | ||
|
||
Installing ViZDoom with `pip install vizdoom[gym]` will include | |
Gym wrappers to interact with ViZDoom over the [Gym API](https://www.gymlibrary.ml/). | |
|
||
These wrappers are under `gym_wrapper`, containing the basic environment and | |
a few example environments based on the built-in scenarios. The basic environment | |
simply initializes ViZDoom with the settings from the scenario config files | |
and implements the necessary API to function as a Gym environment. | |
|
||
See the following examples for usage: | |
- `examples/python/gym_wrapper.py` for basic usage | ||
- `examples/python/learning_stable_baselines.py` for example training with [stable-baselines3](https://github.com/DLR-RM/stable-baselines3/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python3 | ||
|
||
##################################################################### | ||
# Example for running a vizdoom scenario as a gym env | ||
##################################################################### | ||
|
||
import gym | ||
from vizdoom import gym_wrapper | ||
|
||
if __name__ == '__main__':
    env = gym.make("VizdoomHealthGatheringSupreme-v0")

    try:
        # Rendering random rollouts for ten episodes
        for _ in range(10):
            done = False
            obs = env.reset()
            while not done:
                # Sample a random action, advance the environment by one
                # step and display the result.
                obs, rew, done, info = env.step(env.action_space.sample())
                env.render()
    finally:
        # Always let the environment clean up after itself, even when a
        # rollout fails or the script is interrupted.
        env.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#!/usr/bin/env python3 | ||
|
||
##################################################################### | ||
# Example script of training agents with stable-baselines3 | ||
# on ViZDoom using the Gym API | ||
# | ||
# Note: ViZDoom must be installed with optional gym dependencies: | ||
# pip install vizdoom[gym] | ||
# You also need stable-baselines3: | ||
# pip install stable-baselines3 | ||
# | ||
# See more stable-baselines3 documentation here: | ||
# https://stable-baselines3.readthedocs.io/en/master/index.html | ||
##################################################################### | ||
|
||
from argparse import ArgumentParser | ||
|
||
import cv2 | ||
import numpy as np | ||
import gym | ||
import vizdoom.gym_wrapper | ||
|
||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
|
||
DEFAULT_ENV = "VizdoomBasic-v0"
# Ids of all registered Gym environments whose id contains "Vizdoom"
AVAILABLE_ENVS = [spec.id for spec in gym.envs.registry.all() if "Vizdoom" in spec.id]
# Height and width of the resized image
IMAGE_SHAPE = (60, 80)

# Training parameters
TRAINING_TIMESTEPS = 1_000_000
N_STEPS = 128
N_ENVS = 8
# NOTE(review): FRAME_SKIP is not referenced in the visible part of this
# script -- confirm whether it should be passed on to the environments.
FRAME_SKIP = 4
|
||
|
||
class ObservationWrapper(gym.ObservationWrapper):
    """
    Reduce ViZDoom's dictionary observations to a single resized image.

    ViZDoom environments return dictionaries as observations, containing
    the main image as well as other info.
    The image is also too large for normal training.
    This wrapper replaces the dictionary observation space with a simple
    Box space (i.e., only the image stored under the "rgb" key), and also
    resizes the image to a smaller size.

    NOTE: Ideally, you should set the image size to smaller in the scenario files
          for faster running of ViZDoom. This can really impact performance,
          and this code is pretty slow because of this!

    Args:
        env: Environment whose observation space has an "rgb" entry.
        shape: Target (height, width) of the resized image.
    """
    def __init__(self, env, shape=IMAGE_SHAPE):
        super().__init__(env)
        self.image_shape = shape
        # cv2.resize takes the target size as (width, height), i.e. the
        # reverse of the (height, width) order used for `shape`.
        self.image_shape_reverse = shape[::-1]

        # Create new observation space with the new shape
        num_channels = env.observation_space["rgb"].shape[-1]
        new_shape = (shape[0], shape[1], num_channels)
        self.observation_space = gym.spaces.Box(0, 255, shape=new_shape, dtype=np.uint8)

    def observation(self, observation):
        """
        Return the "rgb" image of `observation`, resized to `self.image_shape`.

        The returned array always has a trailing channel axis so that it
        matches the shape of `self.observation_space`.
        """
        observation = cv2.resize(observation["rgb"], self.image_shape_reverse)
        # cv2.resize drops the trailing axis of single-channel images;
        # restore it so the result agrees with the declared Box space.
        if observation.ndim == 2:
            observation = observation[:, :, np.newaxis]
        return observation
|
||
|
||
def main(args):
    """
    Train a stable-baselines3 PPO agent on the environment named by `args.env`.

    Args:
        args: Parsed command-line arguments; only `args.env` is read.
    """
    def scale_reward(reward):
        # Normally the scenarios use large magnitudes for rewards
        # (e.g., 100, -100). This may lead to unstable learning,
        # so the rewards are scaled by 1/100.
        return reward * 0.01

    def wrap_env(env):
        # Two wrappers are applied on the environment:
        #   1) ObservationWrapper, which takes only the image and resizes it
        #   2) A reward scaling wrapper
        resized = ObservationWrapper(env)
        return gym.wrappers.TransformReward(resized, scale_reward)

    # Create multiple environments: this speeds up training with PPO
    envs = make_vec_env(args.env, n_envs=N_ENVS, wrapper_class=wrap_env)

    agent = PPO("CnnPolicy", envs, n_steps=N_STEPS, verbose=1)

    # Do the actual learning
    # This will print out the results in the console.
    # If agent gets better, "ep_rew_mean" should increase steadily
    agent.learn(total_timesteps=TRAINING_TIMESTEPS)
|
||
|
||
if __name__ == "__main__":
    # The first positional parameter of ArgumentParser is `prog` (the
    # program name shown in the usage line), so the summary text must be
    # passed explicitly as `description` to show up in the --help output.
    parser = ArgumentParser(description="Train stable-baselines3 PPO agents on ViZDoom.")
    parser.add_argument("--env",
                        default=DEFAULT_ENV,
                        choices=AVAILABLE_ENVS,
                        help="Name of the environment to play")
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from gym.envs.registration import register | ||
|
||
# Gym environment id -> scenario config file the environment is created with.
_SCENARIO_FILES = {
    "VizdoomBasic-v0": "basic.cfg",
    "VizdoomCorridor-v0": "deadly_corridor.cfg",
    "VizdoomDefendCenter-v0": "defend_the_center.cfg",
    "VizdoomDefendLine-v0": "defend_the_line.cfg",
    "VizdoomHealthGathering-v0": "health_gathering.cfg",
    "VizdoomMyWayHome-v0": "my_way_home.cfg",
    "VizdoomPredictPosition-v0": "predict_position.cfg",
    "VizdoomTakeCover-v0": "take_cover.cfg",
    "VizdoomDeathmatch-v0": "deathmatch.cfg",
    "VizdoomHealthGatheringSupreme-v0": "health_gathering_supreme.cfg",
}

# All environments share one entry point and differ only in the scenario
# file handed to it, so they are registered in a single loop.
for _env_id, _scenario_file in _SCENARIO_FILES.items():
    register(
        id=_env_id,
        entry_point="vizdoom.gym_wrapper.gym_env_defns:VizdoomScenarioEnv",
        kwargs={"scenario_file": _scenario_file},
    )
Oops, something went wrong.