fix(xcy): fix render settings when using gymnasium (#173)
* polish(xcy): fix the render in gymnasium

* polish(xcy): change the arguments of save gif
HarryXuancy authored Dec 22, 2023
1 parent 27188cf commit 95e94b9
Showing 11 changed files with 59 additions and 39 deletions.
requirements.txt (3 changes: 2 additions & 1 deletion)
@@ -3,4 +3,5 @@ gymnasium[atari]
 numpy>=1.22.4
 pympler
 bsuite
-minigrid
+minigrid
+moviepy
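
moviepy joins the requirements because gymnasium's RecordVideo wrapper (adopted in zoo/atari/envs/atari_wrappers.py below) encodes its .mp4 output with it; the minigrid line is deleted and re-added only because the old file lacked a trailing newline. A minimal sanity check that the new dependency resolves:

import moviepy

# If this import fails, RecordVideo will error out when it tries to encode
# a recorded episode.
print(moviepy.__version__)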
zoo/atari/entry/atari_eval.py (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@
 # A boolean flag indicating whether to save the video of the environment.
 main_config.env.save_replay = True
 # The path where the recorded video will be saved.
-main_config.env.save_path = './video'
+main_config.env.replay_path = './video'
 # The maximum number of steps for each episode during evaluation. This may need to be adjusted based on the specific characteristics of the environment.
 main_config.env.eval_max_episode_steps = int(20)

zoo/atari/envs/atari_wrappers.py (22 changes: 11 additions & 11 deletions)
@@ -11,7 +11,7 @@
     ClipRewardWrapper, FrameStackWrapper
 from ding.utils.compression_helper import jpeg_data_compressor
 from easydict import EasyDict
-from gym.wrappers import RecordVideo
+from gymnasium.wrappers import RecordVideo
 
 
 # only for reference now
@@ -93,8 +93,17 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) ->
     if config.render_mode_human:
         env = gymnasium.make(config.env_name, render_mode='human')
     else:
-        env = gymnasium.make(config.env_name)
+        env = gymnasium.make(config.env_name, render_mode='rgb_array')
     assert 'NoFrameskip' in env.spec.id
+    if config.save_replay:
+        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        video_name = f'{env.spec.id}-video-{timestamp}'
+        env = RecordVideo(
+            env,
+            video_folder=config.replay_path,
+            episode_trigger=lambda episode_id: True,
+            name_prefix=video_name
+        )
     env = GymnasiumToGymWrapper(env)
     env = NoopResetWrapper(env, noop_max=30)
     env = MaxAndSkipWrapper(env, skip=config.frame_skip)
@@ -108,15 +117,6 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) ->
         env = ScaledFloatFrameWrapper(env)
     if clip_rewards:
         env = ClipRewardWrapper(env)
-    if config.save_replay:
-        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
-        video_name = f'{env.spec.id}-video-{timestamp}'
-        env = RecordVideo(
-            env,
-            video_folder=config.replay_path,
-            episode_trigger=lambda episode_id: True,
-            name_prefix=video_name
-        )
 
     env = JpegWrapper(env, transform2string=config.transform2string)
     if config.game_wrapper:
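
The substantive change in this file is ordering and API: RecordVideo now comes from gymnasium, wraps the raw gymnasium env before GymnasiumToGymWrapper downgrades everything to the old gym 4-tuple API that the ding wrappers expect, and the env is created with render_mode='rgb_array' so frames can actually be captured. This is also why atari_eval.py above now sets replay_path, the key this code reads. A minimal standalone sketch of the same recording setup, assuming the gymnasium[atari] extras and ROMs are installed; the env id and output folder are illustrative stand-ins for config.env_name and config.replay_path:

from datetime import datetime

import gymnasium
from gymnasium.wrappers import RecordVideo

# render_mode is fixed at construction; RecordVideo pulls frames through
# env.render(), which only returns arrays under 'rgb_array'.
env = gymnasium.make('PongNoFrameskip-v4', render_mode='rgb_array')
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
env = RecordVideo(
    env,
    video_folder='./video',                   # stand-in for config.replay_path
    episode_trigger=lambda episode_id: True,  # record every episode
    name_prefix=f'{env.spec.id}-video-{timestamp}',
)
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()  # flushes the encoder and writes the .mp4 under ./video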
zoo/box2d/bipedalwalker/envs/bipedalwalker_cont_disc_env.py (4 changes: 2 additions & 2 deletions)
@@ -82,7 +82,7 @@ def reset(self) -> np.ndarray:
             - info_dict (:obj:`Dict[str, Any]`): Including observation, action_mask, and to_play label.
         """
         if not self._init_flag:
-            self._env = gym.make('BipedalWalker-v3', hardcore=True)
+            self._env = gym.make('BipedalWalker-v3', hardcore=True, render_mode="rgb_array")
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
             self._reward_space = gym.spaces.Box(
@@ -141,7 +141,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if self._act_scale:
             action = affine_transform(action, min_val=self._raw_action_space.low, max_val=self._raw_action_space.high)
         if self._save_replay_gif:
-            self._frames.append(self._env.render(mode='rgb_array'))
+            self._frames.append(self._env.render())
         obs, rew, terminated, truncated, info = self._env.step(action)
         done = terminated or truncated
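
The box2d and classic_control files all repeat the same two-line pattern seen here: the render mode moves into gym.make(), and render() loses its mode argument. A standalone before/after sketch, assuming gymnasium[box2d] is installed:

import gymnasium as gym

# Old gym API, now a TypeError in gymnasium:
#   env = gym.make('BipedalWalker-v3', hardcore=True)
#   frame = env.render(mode='rgb_array')
# New API: declare the mode once, at construction.
env = gym.make('BipedalWalker-v3', hardcore=True, render_mode="rgb_array")
obs, info = env.reset(seed=0)
frame = env.render()  # RGB ndarray of shape (H, W, 3), ready to use as a GIF frame
print(frame.shape)
env.close()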
zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py (6 changes: 3 additions & 3 deletions)
@@ -83,9 +83,9 @@ def reset(self) -> Dict[str, np.ndarray]:
         if not self._init_flag:
             assert self._cfg.env_type in ['normal', 'hardcore'], "env_type must be in ['normal', 'hardcore']"
             if self._cfg.env_type == 'normal':
-                self._env = gym.make('BipedalWalker-v3')
+                self._env = gym.make('BipedalWalker-v3', render_mode="rgb_array")
             elif self._cfg.env_type == 'hardcore':
-                self._env = gym.make('BipedalWalker-v3', hardcore=True)
+                self._env = gym.make('BipedalWalker-v3', hardcore=True, render_mode="rgb_array")
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
             self._reward_space = gym.spaces.Box(
@@ -134,7 +134,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if self._act_scale:
             action = affine_transform(action, min_val=self.action_space.low, max_val=self.action_space.high)
         if self._save_replay_gif:
-            self._frames.append(self._env.render(mode='rgb_array'))
+            self._frames.append(self._env.render())
 
         obs, rew, terminated, truncated, info = self._env.step(action)
         done = terminated or truncated
zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py (5 changes: 3 additions & 2 deletions)
@@ -5,6 +5,7 @@
 
 import gymnasium as gym
 import numpy as np
+from itertools import product
 from ding.envs import BaseEnvTimestep
 from ding.envs import ObsPlusPrevActRewWrapper
 from ding.envs.common import affine_transform
@@ -84,7 +85,7 @@ def reset(self) -> np.ndarray:
             - info_dict (:obj:`Dict[str, Any]`): Including observation, action_mask, and to_play label.
         """
         if not self._init_flag:
-            self._env = gym.make(self._cfg.env_name)
+            self._env = gym.make(self._cfg.env_name, render_mode="rgb_array")
             if self._replay_path is not None:
                 timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                 video_name = f'{self._env.spec.id}-video-{timestamp}'
@@ -147,7 +148,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if self._act_scale:
             action = affine_transform(action, min_val=-1, max_val=1)
         if self._save_replay_gif:
-            self._frames.append(self._env.render(mode='rgb_array'))
+            self._frames.append(self._env.render())
         obs, rew, terminated, truncated, info = self._env.step(action)
         done = terminated or truncated
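
Besides the render fixes, this file gains from itertools import product. Its call site is outside the shown hunks; in a continuous-to-discrete env, product is the natural tool for enumerating the joint discretized action set, roughly as below (the bin count is an illustrative assumption, not taken from this diff):

from itertools import product

import numpy as np

bins_per_dim = 3   # assumption: K discrete bins per action dimension
action_dim = 2     # LunarLanderContinuous has a 2-dim action space
bin_values = np.linspace(-1.0, 1.0, bins_per_dim)
# Map each discrete action id to one joint assignment of per-dimension bins.
disc_to_cont = list(product(bin_values, repeat=action_dim))
assert len(disc_to_cont) == bins_per_dim ** action_dim  # 3 ** 2 = 9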
zoo/box2d/lunarlander/envs/lunarlander_env.py (4 changes: 2 additions & 2 deletions)
@@ -81,7 +81,7 @@ def reset(self) -> Dict[str, np.ndarray]:
             - obs (:obj:`np.ndarray`): The initial observation after resetting.
         """
         if not self._init_flag:
-            self._env = gym.make(self._cfg.env_name)
+            self._env = gym.make(self._cfg.env_name, render_mode="rgb_array")
             if self._replay_path is not None:
                 timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                 video_name = f'{self._env.spec.id}-video-{timestamp}'
@@ -133,7 +133,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if self._act_scale:
             action = affine_transform(action, min_val=-1, max_val=1)
         if self._save_replay_gif:
-            self._frames.append(self._env.render(mode='rgb_array'))
+            self._frames.append(self._env.render())
 
         obs, rew, terminated, truncated, info = self._env.step(action)
         done = terminated or truncated
zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
@@ -58,7 +58,7 @@ def reset(self) -> Dict[str, np.ndarray]:
         if necessary. Returns the first observation.
         """
         if not self._init_flag:
-            self._env = gym.make('CartPole-v0')
+            self._env = gym.make('CartPole-v0', render_mode="rgb_array")
             if self._replay_path is not None:
                 timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                 video_name = f'{self._env.spec.id}-video-{timestamp}'
zoo/classic_control/pendulum/entry/pendulum_eval.py (39 changes: 28 additions & 11 deletions)
@@ -4,20 +4,37 @@
 import numpy as np
 
 if __name__ == "__main__":
     """
-    model_path (:obj:`Optional[str]`): The pretrained model path, which should
-    point to the ckpt file of the pretrained model, and an absolute path is recommended.
-    In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
-    """
-    model_path = "./ckpt/ckpt_best.pth.tar"
-    seeds = [0]
-    num_episodes_each_seed = 1
-    main_config.env.evaluator_env_num = 1
-    main_config.env.n_evaluator_episode = 1
-    total_test_episodes = num_episodes_each_seed * len(seeds)
+    Entry point for the evaluation of the MuZero model on the Pendulum environment.
+    Variables:
+        - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the
+          pretrained model. An absolute path is recommended. In LightZero, the path is usually something like
+          ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed.
+        - returns_seeds (:obj:`List[float]`): List to store the returns for each seed.
+        - seeds (:obj:`List[int]`): List of seeds for the environment.
+        - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed.
+        - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of
+          seeds and the number of episodes per seed.
+    """
+    # model_path = "./ckpt/ckpt_best.pth.tar"
+    model_path = None
+    returns_mean_seeds = []
+    returns_seeds = []
+    seeds = [0]
+    num_episodes_each_seed = 2
+    total_test_episodes = num_episodes_each_seed * len(seeds)
+    create_config.env_manager.type = 'base'  # Visualization requires the 'type' to be set as base
+    main_config.env.evaluator_env_num = 1  # Visualization requires the 'env_num' to be set as 1
+    main_config.env.n_evaluator_episode = total_test_episodes
+    main_config.env.replay_path = './video'
 
     for seed in seeds:
+        """
+        - returns_mean (:obj:`float`): The mean return of the evaluation.
+        - returns (:obj:`List[float]`): The returns of the evaluation.
+        """
         returns_mean, returns = eval_muzero(
             [main_config, create_config],
             seed=seed,
@@ -36,4 +53,4 @@
     print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
     print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
     print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
-    print("=" * 20)
+    print("=" * 20)
zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py (8 changes: 4 additions & 4 deletions)
@@ -45,9 +45,9 @@ def __init__(self, cfg: dict) -> None:
         self._cfg = cfg
         self._act_scale = cfg.act_scale
         try:
-            self._env = gym.make('Pendulum-v1')
+            self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
         except:
-            self._env = gym.make('Pendulum-v0')
+            self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
         self._init_flag = False
         self._replay_path = cfg.replay_path
         self._continuous = cfg.get("continuous", True)
@@ -71,9 +71,9 @@ def reset(self) -> Dict[str, np.ndarray]:
         """
         if not self._init_flag:
             try:
-                self._env = gym.make('Pendulum-v1')
+                self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
             except:
-                self._env = gym.make('Pendulum-v0')
+                self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
             if self._replay_path is not None:
                 timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                 video_name = f'{self._env.spec.id}-video-{timestamp}'
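
The try/except pair is a version fallback: newer registries ship Pendulum-v1, older ones only Pendulum-v0. A standalone sketch of the pattern; catching gym.error.Error rather than using a bare except is our tightening, not what the diff does:

import gymnasium as gym

try:
    env = gym.make('Pendulum-v1', render_mode="rgb_array")
except gym.error.Error:  # Pendulum-v1 not registered in this install
    env = gym.make('Pendulum-v0', render_mode="rgb_array")
obs, info = env.reset(seed=0)
print(env.spec.id)
env.close()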
zoo/minigrid/envs/minigrid_lightzero_env.py (3 changes: 2 additions & 1 deletion)
@@ -171,7 +171,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if self._save_replay_gif:
             self._frames.append(self._env.render())
         # using the step method of Gymnasium env, return is (observation, reward, terminated, truncated, info)
-        obs, rew, done, _, info = self._env.step(action)
+        obs, rew, terminated, truncated, info = self._env.step(action)
+        done = terminated or truncated
         rew = float(rew)
         self._eval_episode_return += rew
         self._current_step += 1
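
This hunk is the one behavioral bug fix beyond rendering: gymnasium's step() returns (obs, reward, terminated, truncated, info), and the old unpacking read the third element as done, silently discarding time-limit truncations, so an episode cut off by max_steps was never flagged as finished. A minimal reproduction of the corrected pattern, assuming the minigrid package from requirements.txt:

import gymnasium as gym
import minigrid  # noqa: F401  (importing registers the MiniGrid envs)

env = gym.make("MiniGrid-Empty-5x5-v0")
obs, info = env.reset(seed=0)
terminated = truncated = False
steps = 0
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    steps += 1
done = terminated or truncated  # the correct episode-end signal
print(steps, terminated, truncated)
env.close()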
