fix(pu): fix memory_lightzero_env return bug
puyuan1996 committed Mar 19, 2024
1 parent 9b0c0ae commit c467ec4
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions zoo/memory/envs/memory_lightzero_env.py
@@ -19,9 +19,6 @@ class MemoryEnvLightZero(BaseEnv):
     """
     Overview:
         The MemoryEnvLightZero environment for LightZero, based on the Visual-Match and Key-to-Door Task from DeepMind.
-        Please refer to the following repository for more details:
-        - https://github.com/deepmind/deepmind-research/tree/master/tvt/pycolab
-        - https://github.com/twni2016/Memory-RL
     Attributes:
         config (dict): Configuration dict. Default configurations can be updated using this.
         _cfg (dict): Internal configuration dict that stores runtime configurations.
@@ -30,11 +27,14 @@ class MemoryEnvLightZero(BaseEnv):
         _save_replay (bool): Flag to check if replays are saved.
         _render (bool): Flag to check if real-time rendering is enabled.
         _gif_images (list): List to store frames for creating GIF replay.
+        _max_step (int): Maximum number of steps for the environment.
     """
     config = dict(
         env_id='visual_match',  # The name of the environment, options: 'visual_match', 'key_to_door'
         # max_step=60,  # The maximum number of steps for each episode
         num_apples=10,  # Number of apples in the distractor phase
         # apple_reward=(1, 10),  # Range of rewards for collecting an apple
+        # apple_reward=(1, 1),  # Range of rewards for collecting an apple
+        apple_reward=(0, 0),  # Range of rewards for collecting an apple
         fix_apple_reward_in_episode=False,  # Whether to fix apple reward (DEFAULT_APPLE_REWARD) within an episode
         final_reward=10.0,  # Reward for choosing the correct door in the final phase
@@ -84,7 +84,8 @@ def reset(self) -> np.ndarray:
         elif hasattr(self, '_seed'):
             self._rng = np.random.RandomState(self._seed)
         else:
-            self._rng = np.random.RandomState(0)
+            self._seed = 0  # TODO
+            self._rng = np.random.RandomState(self._seed)
         print(f'memory_lightzero_env reset self._seed: {self._seed}')
         if self._cfg.env_id == 'visual_match':
             from zoo.memory.envs.pycolab_tvt.visual_match import Game, PASSIVE_EXPLORE_GRID
@@ -100,7 +101,7 @@ def reset(self) -> np.ndarray:
                 EXPLORE_GRID=PASSIVE_EXPLORE_GRID,
             )
         elif self._cfg.env_id == 'key_to_door':
-            from zoo.memory.envs.pycolab_tvt.key_to_door import Game, REWARD_GRID_SR
+            from zoo.memory.envs.pycolab_tvt.key_to_door import Game, REWARD_GRID_SR, MAX_FRAMES_PER_PHASE_SR
             self._game = Game(
                 self._rng,
                 num_apples=self._cfg.num_apples,
@@ -122,6 +123,7 @@ def reset(self) -> np.ndarray:
         self._reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1,), dtype=np.float32)
 
         self._current_step = 0
+        self.episode_reward_list = []
         self._eval_episode_return = 0
         obs, _, _ = self._episode.its_showtime()
         obs = obs[0].reshape(1, 5, 5)
@@ -132,6 +134,7 @@ def reset(self) -> np.ndarray:
         if self._cfg.flate_observation:
             obs = obs.flatten()
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
         self._gif_images = []
 
         return obs
@@ -150,6 +153,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
             action = action.squeeze()  # 0-dim array
 
         observation, reward, _ = self._episode.play(action)
+        self.episode_reward_list.append(reward)
         observation = observation[0].reshape(1, 5, 5)
 
         self._current_step += 1
@@ -160,8 +164,9 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if done:
             # TODO
             info['eval_episode_return'] = self._eval_episode_return
-            info['success'] = 1 if reward == self._cfg.final_reward else 0
+            info['success'] = 1 if self._eval_episode_return == self._cfg.final_reward else 0
             info['eval_episode_return'] = info['success']
+            print(f'episode seed:{self._seed} done! self.episode_reward_list is: {self.episode_reward_list}')
 
         observation = to_ndarray(observation, dtype=np.float32)
         reward = to_ndarray([reward])
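Note on the fix (not part of the commit): with apple_reward=(0, 0) the distractor apples pay nothing, so the accumulated episode return equals final_reward exactly when the agent reaches the correct door. The old check compared only the final step's reward against final_reward; the new check compares the accumulated return, and eval_episode_return is then reported as the 0/1 success flag. The sketch below is a minimal, hypothetical illustration of this accounting (the summarize_episode helper is not in the repo), assuming apple_reward=(0, 0) and final_reward=10.0 as in the config above.

# Hypothetical sketch of the fixed success accounting; not code from the commit.
from typing import Dict, List


def summarize_episode(episode_rewards: List[float], final_reward: float = 10.0) -> Dict[str, int]:
    # The env accumulates per-step rewards into _eval_episode_return;
    # here we just sum the recorded per-step rewards.
    eval_episode_return = sum(episode_rewards)
    # Success iff the accumulated return matches the door reward exactly;
    # this only holds because apples are configured to give zero reward.
    success = 1 if eval_episode_return == final_reward else 0
    # After this commit, eval_episode_return is reported as the success flag.
    return {'success': success, 'eval_episode_return': success}


print(summarize_episode([0.0] * 14 + [10.0]))  # {'success': 1, 'eval_episode_return': 1}
print(summarize_episode([0.0] * 15))           # {'success': 0, 'eval_episode_return': 0}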
