-
Notifications
You must be signed in to change notification settings - Fork 148
/
gym_env.py
118 lines (94 loc) · 3.82 KB
/
gym_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gym
from gym import spaces as gym_spaces
import numpy as np
try:
import pybullet_envs
import time
pybullet_found = True
except ImportError:
pybullet_found = False
from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.utils.spaces import *
# Reduce gym's logging verbosity at import time. The numeric level 40
# presumably corresponds to gym.logger.ERROR, so only messages at that level
# or above are emitted — confirm against the installed gym version.
gym.logger.set_level(40)
class Gym(Environment):
    """
    Interface for OpenAI Gym environments. It makes it possible to use every
    Gym environment just providing the id, except for the Atari games that
    are managed in a separate class.

    """
    def __init__(self, name, horizon=None, gamma=0.99, wrappers=None, wrappers_args=None,
                 **env_args):
        """
        Constructor.

        Args:
            name (str): gym id of the environment;
            horizon (int, None): the horizon. If None, use the one from Gym;
            gamma (float, 0.99): the discount factor;
            wrappers (list, None): list of wrappers to apply over the
                environment, in order. It is possible to pass keyword
                arguments to a wrapper by providing a tuple with two
                elements: the gym wrapper class and a dictionary containing
                the parameters needed by the wrapper constructor;
            wrappers_args (list, None): list of lists of positional arguments,
                one entry per wrapper;
            **env_args: other gym environment parameters, forwarded to
                ``gym.make`` only (never to the wrappers).

        """
        # MDP creation
        self._not_pybullet = True
        self._first = True

        # Entries returned by pybullet_envs.getList() are matched against
        # '- ' + name, i.e. they carry a '- ' prefix.
        if pybullet_found and '- ' + name in pybullet_envs.getList():
            import pybullet
            pybullet.connect(pybullet.DIRECT)
            self._not_pybullet = False

        self.env = gym.make(name, **env_args)

        if wrappers is not None:
            if wrappers_args is None:
                # No extra positional arguments for any wrapper.
                wrappers_args = [()] * len(wrappers)
            # Note: zip stops at the shorter of the two lists.
            for wrapper, args in zip(wrappers, wrappers_args):
                if isinstance(wrapper, tuple):
                    # (wrapper_class, kwargs) form.
                    self.env = wrapper[0](self.env, *args, **wrapper[1])
                else:
                    # env_args are gym.make parameters; they must not be
                    # forwarded to the wrapper constructor.
                    self.env = wrapper(self.env, *args)

        if horizon is None:
            horizon = self.env._max_episode_steps
        # Hack to ignore gym time limit.
        # NOTE(review): when wrappers are applied, this reads and writes the
        # attribute on the outermost wrapper — confirm the inner time limit
        # is really disabled in that case.
        self.env._max_episode_steps = np.inf

        # MDP properties
        assert not isinstance(self.env.observation_space,
                              gym_spaces.MultiDiscrete)
        assert not isinstance(self.env.action_space, gym_spaces.MultiDiscrete)

        action_space = self._convert_gym_space(self.env.action_space)
        observation_space = self._convert_gym_space(self.env.observation_space)
        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)

        if isinstance(action_space, Discrete):
            # Discrete actions are unwrapped to their first element before
            # being passed to gym (presumably a 1-element array — confirm).
            self._convert_action = lambda a: a[0]
        else:
            self._convert_action = lambda a: a

        super().__init__(mdp_info)

    def reset(self, state=None):
        """
        Reset the environment.

        Args:
            state (array-like, None): if provided, the gym environment is
                reset and then ``self.env.state`` is overwritten with this
                value. This assumes the wrapped env exposes a writable
                ``state`` attribute — TODO confirm for the env in use.

        Returns:
            The initial observation (or ``state``), as an at-least-1d array.

        """
        if state is None:
            return np.atleast_1d(self.env.reset())
        else:
            self.env.reset()
            self.env.state = state
            return np.atleast_1d(state)

    def step(self, action):
        """
        Apply an action to the gym environment.

        Args:
            action: the action chosen by the agent; it is passed through
                ``self._convert_action`` before reaching gym.

        Returns:
            The tuple (observation, reward, absorbing, info). The observation
            is an at-least-1d array, and the flag returned by gym's 4-tuple
            ``step`` API is forwarded as ``absorbing``.

        """
        action = self._convert_action(action)
        obs, reward, absorbing, info = self.env.step(action)

        return np.atleast_1d(obs), reward, absorbing, info

    def render(self, mode='human'):
        """
        Render the environment.

        Non-pybullet environments are rendered on every call. Pybullet
        environments are rendered only on the first call (presumably because
        pybullet keeps its own viewer open — confirm).

        Args:
            mode (str, 'human'): render mode forwarded to gym.

        """
        if self._first or self._not_pybullet:
            self.env.render(mode=mode)

            self._first = False

    def stop(self):
        """
        Close the underlying gym environment. Pybullet environments are not
        closed. Closing is best-effort: errors raised by ``close`` are ignored.

        """
        try:
            if self._not_pybullet:
                self.env.close()
        except Exception:
            # Best-effort cleanup; unlike a bare except, this lets
            # KeyboardInterrupt/SystemExit propagate.
            pass

    @staticmethod
    def _convert_gym_space(space):
        """
        Convert a gym space into the equivalent mushroom_rl space.

        Args:
            space: a gym ``Discrete`` or ``Box`` space.

        Returns:
            The corresponding mushroom_rl ``Discrete`` or ``Box`` space.

        Raises:
            ValueError: if the gym space type is not supported.

        """
        if isinstance(space, gym_spaces.Discrete):
            return Discrete(space.n)
        elif isinstance(space, gym_spaces.Box):
            return Box(low=space.low, high=space.high, shape=space.shape)
        else:
            raise ValueError(
                'Unsupported gym space type: {}'.format(type(space).__name__))