Skip to content

Commit

Permalink
[FIX] wrapper_args is a list of dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
AhmedMagdyHendawy committed Sep 2, 2024
1 parent c6ce144 commit 72070fc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions mushroom_rl/environments/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args=
a tuple with two elements: the gym wrapper class and a
dictionary containing the parameters needed by the wrapper
constructor;
wrappers_args (list, None): list of list of arguments for each wrapper;
wrappers_args (list, None): list of dictionaries of arguments for each wrapper;
** env_args: other gym environment parameters.
"""
Expand All @@ -56,9 +56,9 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args=
wrappers_args = [dict()] * len(wrappers)
for wrapper, args in zip(wrappers, wrappers_args):
if isinstance(wrapper, tuple):
self.env = wrapper[0](self.env, *args, **wrapper[1])
self.env = wrapper[0](self.env, **args, **wrapper[1])
else:
self.env = wrapper(self.env, *args, **env_args)
self.env = wrapper(self.env, **args, **env_args)

horizon = self._set_horizon(self.env, horizon)

Expand Down
6 changes: 3 additions & 3 deletions mushroom_rl/environments/gymnasium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, name, horizon=None, gamma=0.99, headless = False, wrappers=No
a tuple with two elements: the gym wrapper class and a
dictionary containing the parameters needed by the wrapper
constructor;
wrappers_args (list, None): list of list of arguments for each wrapper;
wrappers_args (list, None): list of dictionaries of arguments for each wrapper;
** env_args: other gym environment parameters.
"""
Expand All @@ -60,9 +60,9 @@ def __init__(self, name, horizon=None, gamma=0.99, headless = False, wrappers=No
wrappers_args = [dict()] * len(wrappers)
for wrapper, args in zip(wrappers, wrappers_args):
if isinstance(wrapper, tuple):
self.env = wrapper[0](self.env, *args, **wrapper[1])
self.env = wrapper[0](self.env, **args, **wrapper[1])
else:
self.env = wrapper(self.env, *args, **env_args)
self.env = wrapper(self.env, **args, **env_args)

horizon = self._set_horizon(self.env, horizon)

Expand Down

0 comments on commit 72070fc

Please sign in to comment.