a few more fixes
DenSumy committed Dec 29, 2023
1 parent a59d1a0 commit b32474d
Showing 12 changed files with 47 additions and 41 deletions.
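For context: nearly all of these changes migrate rl_games wrappers and environment adapters from the legacy Gym API to the Gymnasium API. A minimal sketch of the two calling conventions involved, for illustration only (not part of the commit):

import gymnasium as gym

env = gym.make("CartPole-v1")

# Gymnasium reset takes the seed as a keyword and returns (obs, info)
obs, info = env.reset(seed=0)

# Gymnasium step returns a 5-tuple; the old single done flag is split in two
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated  # the collapsing pattern used throughout this commit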
3 changes: 2 additions & 1 deletion rl_games/common/gymnasium_utils.py
@@ -102,7 +102,8 @@ def convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space
:param space: A gym/gymnasium Space
:return: Patched space (gymnasium Space)
"""

if space is None:
return None
# Gymnasium space, no conversion to be done
if isinstance(space, gymnasium.Space):
return space
5 changes: 2 additions & 3 deletions rl_games/common/vecenv.py
@@ -139,7 +139,7 @@ def render(self):
pass

def reset(self):
obs, info = self.env.reset(self.saved_seed) # ignoring info for now
obs, info = self.env.reset(seed=self.saved_seed) # ignoring info for now
obs = self._obs_to_fp32(obs)
return obs

Expand Down Expand Up @@ -192,8 +192,7 @@ def __init__(self, config_name, num_actors, **kwargs):
self.use_torch = False
self.seed = kwargs.pop('seed', None)


self.remote_worker = self.ray.remote(RayWorker)
self.remote_worker = self.ray.remote(RayWorkerGymnasium)
self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]

if self.seed is not None:
46 changes: 25 additions & 21 deletions rl_games/common/wrappers.py
@@ -67,7 +67,8 @@ def __init__(self, env):

def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
obs, _, terminated, truncated, _ = self.env.step(1)
done = terminated or truncated
if done:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(2)
done = terminated or truncated
@@ -89,7 +90,8 @@ def __init__(self, env):
self.was_real_done = True

def step(self, action):
obs, reward, done, info = self.env.step(action)
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
@@ -124,7 +126,8 @@ def __init__(self, env):
self.current_steps=0

def step(self, action):
obs, reward, done, info = self.env.step(action)
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if reward == 0:
self.current_steps += 1
else:
@@ -135,7 +138,7 @@ def step(self, action):
done = True
reward = -1
obs, _ = self.env.reset()
return obs, reward, done, info
return obs, reward, done, truncated, info


class MaxAndSkipEnv(gym.Wrapper):
@@ -234,16 +237,16 @@ def __init__(self, env, k, flat = False):
else:
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=observation_space.dtype)

def reset(self):
ob = self.env.reset()
def reset(self, **kwargs):
ob, info = self.env.reset(**kwargs)
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()

def step(self, action):
ob, reward, done, info = self.env.step(action)
ob, reward, terminated, truncated, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
return self._get_ob(), reward, terminated, truncated, info

def _get_ob(self):
assert len(self.frames) == self.k
@@ -277,16 +280,16 @@ def __init__(self, env, k, transpose = False, flatten = False):
else:
self.observation_space = spaces.Box(low=0, high=1, shape=(k, shp[0]), dtype=env.observation_space.dtype)

def reset(self):
ob = self.env.reset()
def reset(self, **kwargs):
ob = self.env.reset(**kwargs)
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()

def step(self, action):
ob, reward, done, info = self.env.step(action)
ob, reward, terminated, truncated, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
return self._get_ob(), reward, terminated, truncated, info

def _get_ob(self):
assert len(self.frames) == self.k
@@ -324,8 +327,8 @@ def __init__(self, env, k, transpose = False, flatten = False):
self.observation_space = spaces.Box(low=0, high=1, shape=(k, shp[0]), dtype=env.observation_space.dtype)
self.state_space = spaces.Box(low=0, high=1, shape=(k, state_shp[0]), dtype=env.observation_space.dtype)

def reset(self):
obs_dict = self.env.reset()
def reset(self, **kwargs):
obs_dict = self.env.reset(**kwargs)
ob = obs_dict["obs"]
state = obs_dict["state"]
for _ in range(self.k):
@@ -378,9 +381,9 @@ def __init__(self, env, k = 2, greyscale=True):
shape = (shp[:-1] + (shp[-1] * k,))
self.observation_space = spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)

def reset(self):
def reset(self, **kwargs):
import cv2
frames = self.env.reset()
frames = self.env.reset(**kwargs)
self.frames.append(frames)

if self.greyscale:
@@ -511,9 +514,9 @@ def __init__(self, env, p=0.25):
self.p = p
self.last_action = 0

def reset(self):
def reset(self, **kwargs):
self.last_action = 0
return self.env.reset()
return self.env.reset(**kwargs)

def step(self, action):
if self.unwrapped.np_random.uniform() < self.p:
@@ -535,7 +538,8 @@ def get_current_room(self):
return int(ram[self.room_address])

def step(self, action):
obs, rew, done, info = self.env.step(action)
obs, rew, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
self.visited_rooms.add(self.get_current_room())
if done:
if 'scores' not in info:
@@ -544,7 +548,7 @@ def step(self, action):
self.visited_rooms.clear()
return obs, rew, done, info

def reset(self):
def reset(self, **kwargs):
return self.env.reset(**kwargs)


@@ -595,7 +599,7 @@ def step(self, action):
}
return obs, reward, done, info

def reset(self):
def reset(self, **kwargs):
obs = self.env.reset()
obs = {
'observation': obs,
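The wrapper edits above all follow the same pattern: reset() gains **kwargs so seed/options propagate through the wrapper chain, and step() unpacks the 5-tuple, collapsing terminated/truncated back into a single done flag where the surrounding code still expects one. A hedged sketch of that pattern (illustrative only, not code from the repository; the wrapper name is made up):

import gymnasium as gym

class DoneCollapseWrapper(gym.Wrapper):
    # Hypothetical wrapper; shows the adaptation pattern used across wrappers.py.
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)  # forward seed/options to the wrapped env
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        info['time_outs'] = truncated  # keep the truncation signal available downstream
        return obs, reward, done, info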
2 changes: 1 addition & 1 deletion rl_games/configs/test/test_asymmetric_continuous.yaml
@@ -58,7 +58,7 @@ params:
normalize_value: True
seq_length: 4
weight_decay: 0.0000
multi_gpu: True
multi_gpu: False
bounds_loss_coef: 0.0
use_diagnostics: True
env_config:
2 changes: 1 addition & 1 deletion rl_games/envs/brax.py
@@ -43,7 +43,7 @@ def step(self, action):
is_done = jax_to_torch(is_done)
return next_obs, reward, is_done, info

def reset(self):
def reset(self, **kwargs):
# todo add random init like in collab examples?
obs = self.env.reset()
return jax_to_torch(obs)
2 changes: 1 addition & 1 deletion rl_games/envs/diambra/diambra.py
@@ -75,7 +75,7 @@ def __init__(self, **kwargs):
def _preproc_state_obs(self, obs):
return obs

def reset(self):
def reset(self, **kwargs):
self._game_num += 1
obs = self.env.reset() # rename, to think remove
obs_dict = self._preproc_state_obs(obs)
10 changes: 6 additions & 4 deletions rl_games/envs/envpool.py
@@ -21,7 +21,7 @@ def __init__(self, config_name, num_actors, **kwargs):
self.use_dict_obs_space = kwargs.pop('use_dict_obs_space', False)
self.flatten_obs = kwargs.pop('flatten_obs', False) # for the dm control
self.env = envpool.make( env_name,
env_type=kwargs.pop('env_type', 'gym'),
env_type=kwargs.pop('env_type', 'gymnasium'),
num_envs=num_actors,
batch_size=self.batch_size,
**kwargs
@@ -63,8 +63,9 @@ def _set_scores(self, infos, dones):
self.scores *= 1 - dones

def step(self, action):
next_obs, reward, is_done, info = self.env.step(action , self.ids)
info['time_outs'] = info['TimeLimit.truncated']
next_obs, reward, is_terminated, is_truncated, info = self.env.step(action , self.ids)
is_done = is_terminated | is_truncated
info['time_outs'] = is_truncated
self._set_scores(info, is_done)
if self.flatten_obs:
next_obs = flatten_dict(next_obs)
@@ -78,6 +79,7 @@ def step(self, action):

def reset(self):
obs = self.env.reset(self.ids)
obs = obs[0]
if self.flatten_obs:
obs = flatten_dict(obs)
if self.use_dict_obs_space:
@@ -86,7 +88,7 @@ def reset(self):
'reward': np.zeros(obs.shape[0]),
'last_action': np.zeros(obs.shape[0]),
}

return obs

def get_number_of_agents(self):
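In the vectorized envpool adapter, the same split shows up batch-wise: terminated and truncated arrive as boolean arrays, are combined with | to form the done mask, and the truncation array is exposed as info['time_outs'], which is commonly used to let the learner bootstrap the value target on time-limit endings rather than treating them as true failures. A small batched sketch with made-up values, for illustration only:

import numpy as np

terminated = np.array([False, True, False, False])
truncated = np.array([False, False, True, False])

is_done = terminated | truncated  # episode boundary: reset these envs, log episode stats
time_outs = truncated             # kept separately so timeouts can still bootstrap V(s')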
4 changes: 2 additions & 2 deletions rl_games/envs/slimevolley_selfplay.py
@@ -20,10 +20,10 @@ def __init__(self, name="SlimeVolleyDiscrete-v0", **kwargs):
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space

def reset(self):
def reset(self, **kwargs):
if self.agent == None:
self.create_agent(self.config_path)
obs = self.env.reset()
obs = self.env.reset(**kwargs)
self.opponent_obs = obs
self.sum_rewards = 0
return obs
2 changes: 1 addition & 1 deletion rl_games/envs/smac_env.py
@@ -48,7 +48,7 @@ def _preproc_state_obs(self, state, obs):
def get_number_of_agents(self):
return self.n_agents

def reset(self):
def reset(self, **kwargs):
if self._game_num % self.replay_save_freq == 1:
print('saving replay')
self.env.save_replay()
2 changes: 1 addition & 1 deletion rl_games/envs/smac_v2_env.py
@@ -54,7 +54,7 @@ def _preproc_state_obs(self, state, obs):
def get_number_of_agents(self):
return self.n_agents

def reset(self):
def reset(self, **kwargs):
if self._game_num % self.replay_save_freq == 1:
print('saving replay')
self.env.save_replay()
4 changes: 2 additions & 2 deletions rl_games/envs/test/rnn_env.py
@@ -45,7 +45,7 @@ def __init__(self, **kwargs):
def get_number_of_agents(self):
return 1

def reset(self):
def reset(self, seed=None):
self._curr_steps = 0
self._current_pos = [0,0]
bound = self.max_dist - self.min_dist
@@ -143,7 +143,7 @@ def step(self, action):
else:
reward = reward[0] + reward[1]

return obses, np.array(reward).astype(np.float32), done, info
return obses, np.array(reward).astype(np.float32), done, False, info

def has_action_mask(self):
return False
6 changes: 3 additions & 3 deletions rl_games/envs/test/test_asymmetric_env.py
@@ -26,8 +26,8 @@ def __init__(self, wrapped_env_name, **kwargs):
def get_number_of_agents(self):
return self.n_agents

def reset(self):
obs = self.env.reset()
def reset(self, seed=None):
obs, info = self.env.reset(seed=seed)
obs_dict = {}
obs_dict["obs"] = obs * self.mask
obs_dict["state"] = obs
@@ -46,7 +46,7 @@ def step(self, actions):
obses = obs_dict
else:
obses = obs_dict["obs"].astype(np.float32)
return obses, rewards, dones, info
return obses, rewards, dones, False, info

def has_action_mask(self):
return False
