From 72070fcf12d24da75e4a6880f17b9318715a410c Mon Sep 17 00:00:00 2001 From: Ahmed Hendawy Date: Mon, 2 Sep 2024 11:22:14 +0200 Subject: [PATCH] [FIX] wrapper_args is a list of dicts --- mushroom_rl/environments/gym_env.py | 6 +++--- mushroom_rl/environments/gymnasium_env.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mushroom_rl/environments/gym_env.py b/mushroom_rl/environments/gym_env.py index 325c2d2b..191d8c84 100644 --- a/mushroom_rl/environments/gym_env.py +++ b/mushroom_rl/environments/gym_env.py @@ -36,7 +36,7 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= a tuple with two elements: the gym wrapper class and a dictionary containing the parameters needed by the wrapper constructor; - wrappers_args (list, None): list of list of arguments for each wrapper; + wrappers_args (list, None): list of dictionaries of arguments for each wrapper; ** env_args: other gym environment parameters. """ @@ -56,9 +56,9 @@ def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args= wrappers_args = [dict()] * len(wrappers) for wrapper, args in zip(wrappers, wrappers_args): if isinstance(wrapper, tuple): - self.env = wrapper[0](self.env, *args, **wrapper[1]) + self.env = wrapper[0](self.env, **args, **wrapper[1]) else: - self.env = wrapper(self.env, *args, **env_args) + self.env = wrapper(self.env, **args, **env_args) horizon = self._set_horizon(self.env, horizon) diff --git a/mushroom_rl/environments/gymnasium_env.py b/mushroom_rl/environments/gymnasium_env.py index 80519a12..ee42945c 100644 --- a/mushroom_rl/environments/gymnasium_env.py +++ b/mushroom_rl/environments/gymnasium_env.py @@ -38,7 +38,7 @@ def __init__(self, name, horizon=None, gamma=0.99, headless = False, wrappers=No a tuple with two elements: the gym wrapper class and a dictionary containing the parameters needed by the wrapper constructor; - wrappers_args (list, None): list of list of arguments for each wrapper; + wrappers_args (list, None): list of dictionaries of arguments for each wrapper; ** env_args: other gym environment parameters. """ @@ -60,9 +60,9 @@ def __init__(self, name, horizon=None, gamma=0.99, headless = False, wrappers=No wrappers_args = [dict()] * len(wrappers) for wrapper, args in zip(wrappers, wrappers_args): if isinstance(wrapper, tuple): - self.env = wrapper[0](self.env, *args, **wrapper[1]) + self.env = wrapper[0](self.env, **args, **wrapper[1]) else: - self.env = wrapper(self.env, *args, **env_args) + self.env = wrapper(self.env, **args, **env_args) horizon = self._set_horizon(self.env, horizon)