Skip to content

Commit

Permalink
Adding new functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaKapoor74 committed Sep 3, 2020
1 parent 8d2cf06 commit 1365585
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 8 deletions.
104 changes: 104 additions & 0 deletions genrl/environments/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,107 @@ def close(self) -> None:
Closes environment
"""
self.env.close()


class MultiGymWrapper(gym.Wrapper):
    """
    Wrapper class for all MultiAgent Particle Environments

    :param env: Gym environment name
    :param n_envs: Number of environments. None if not vectorised
    :param parallel: If vectorised, should environments be run \
serially or in parallel
    :type env: string
    :type n_envs: None, int
    :type parallel: boolean
    """

    def __init__(self, env: gym.Env):
        # Bug fix: the original called super(GymWrapper, self), copied from
        # the sibling GymWrapper class. That raises TypeError because self is
        # not an instance of GymWrapper; super() must name THIS class.
        super(MultiGymWrapper, self).__init__(env)
        self.env = env

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

        # Last transition seen via step(); populated lazily.
        self.state = None
        self.action = None
        self.reward = None
        self.done = False
        self.info = {}

    def __getattr__(self, name: str) -> Any:
        """
        All other calls would go to base env
        """
        # Bug fix: same wrong class name as in __init__; use MultiGymWrapper.
        env = super(MultiGymWrapper, self).__getattribute__("env")
        return getattr(env, name)

    @property
    def obs_shape(self):
        # Shape of a single observation: Discrete spaces are treated as a
        # 1-element vector, Box spaces report their native shape.
        if isinstance(self.env.observation_space, gym.spaces.Discrete):
            return (1,)
        elif isinstance(self.env.observation_space, gym.spaces.Box):
            return self.env.observation_space.shape
        # Bug fix: the original fell through and raised UnboundLocalError for
        # any other space type; raise an explicit, debuggable error instead.
        raise NotImplementedError(
            "Unsupported observation space: {}".format(
                type(self.env.observation_space)
            )
        )

    @property
    def action_shape(self):
        # Shape of a single action, mirroring obs_shape above.
        if isinstance(self.env.action_space, gym.spaces.Box):
            return self.env.action_space.shape
        elif isinstance(self.env.action_space, gym.spaces.Discrete):
            return (1,)
        raise NotImplementedError(
            "Unsupported action space: {}".format(type(self.env.action_space))
        )

    def sample(self) -> np.ndarray:
        """
        Shortcut method to directly sample from environment's action space

        :returns: Random action from action space
        :rtype: NumPy Array
        """
        return self.env.action_space.sample()

    def render(self, mode: str = "human") -> None:
        """
        Renders all envs in a tiles format similar to baselines.

        :param mode: Can either be 'human' or 'rgb_array'. \
Displays tiled images in 'human' and returns tiled images in 'rgb_array'
        :type mode: string
        """
        self.env.render(mode=mode)

    def seed(self, seed: int = None) -> None:
        """
        Set environment seed

        :param seed: Value of seed
        :type seed: int
        """
        self.env.seed(seed)

    def step(self, action: np.ndarray) -> np.ndarray:
        """
        Steps the env through given action

        :param action: Action taken by agent
        :type action: NumPy array
        :returns: Next observation, reward, game status and debugging info
        """
        # Cache the transition so callers can inspect the latest step.
        self.state, self.reward, self.done, self.info = self.env.step(action)
        self.action = action
        return self.state, self.reward, self.done, self.info

    def reset(self) -> np.ndarray:
        """
        Resets environment

        :returns: Initial state
        """
        return self.env.reset()

    def close(self) -> None:
        """
        Closes environment
        """
        self.env.close()
71 changes: 63 additions & 8 deletions genrl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,52 @@ def mlp(

return nn.Sequential(*layers)


# If at all you need to concatenate states to actions after passing states
# through n FC layers
def mlp_(
    self,
    layer_sizes,
    weight_init,
    activation_func,
    concat_ind,
    sac,
):
    """
    Generates an MLP model given sizes of each layer

    NOTE(review): ``self`` is unused; it looks copied from a method context.
    Kept for call-site compatibility.

    :param layer_sizes: Sizes of hidden layers
    :param weight_init: type of weight initialization
    :param activation_func: type of activation function
    :param concat_ind: index of layer at which actions are to be concatenated
        (that Linear layer is skipped here; the caller wires the junction)
    :param sac: True if Soft Actor Critic is being used, else False
    :type layer_sizes: tuple or list
    :type concat_ind: int
    :type sac: bool
    :type weight_init,activation_func: string
    :returns: Neural Network with fully-connected linear layers and
        activation layers
    """
    layers = []
    # Bug fix: the original referenced the undefined name ``sizes`` here
    # (and again inside the loop), so every call raised NameError.
    limit = len(layer_sizes) if sac is False else len(layer_sizes) - 1

    # add more activations
    activation = nn.Tanh() if activation_func == "tanh" else nn.ReLU()

    # add more weight init; fall back to None for unrecognised names so we
    # never try to call a plain string below
    if weight_init == "xavier_uniform":
        weight_init = torch.nn.init.xavier_uniform_
    elif weight_init == "xavier_normal":
        weight_init = torch.nn.init.xavier_normal_
    else:
        weight_init = None

    for layer in range(limit - 1):
        # Skip the layer at which actions get concatenated.
        if layer == concat_ind:
            continue
        # Hidden layers get the activation; the output layer gets Identity.
        act = activation if layer < limit - 2 else nn.Identity()
        layers += [nn.Linear(layer_sizes[layer], layer_sizes[layer + 1]), act]
        # Bug fix: the original did weight_init(layers[-1][0].weight), but
        # layers[-1] is the activation module (not subscriptable, no weight).
        # The freshly appended Linear layer is layers[-2].
        if weight_init is not None:
            weight_init(layers[-2].weight)

    # Bug fix: the original built ``layers`` but never returned anything;
    # mirror the sibling ``mlp`` helper and return a Sequential.
    return nn.Sequential(*layers)


def shared_mlp(
network1_prev,
network2_prev,
Expand Down Expand Up @@ -102,14 +148,6 @@ def shared_mlp(
network2_post = nn.ModuleList()
# add more activation functions
if activation_func == "relu":
activation = F.relu
elif activation_func == "tanh":
activation = torch.tanh
else:
activation = None
# add more weight init
if weight_init == "xavier_uniform":
weight_init = torch.nn.init.xavier_uniform_
Expand All @@ -118,32 +156,49 @@ def shared_mlp(
else:
weight_init = None
if activation_func == "relu":
activation = nn.ReLU()
elif activation_func == "tanh":
activation = nn.Tanh()
else:
activation = None
if len(shared) != 0 or len(network1_post) != 0 or len(network2_post) != 0:
if not (network1_prev[-1]==network2_prev[-1] and network1_prev[-1]==shared[0] and network1_post[0]==network2_post[0] and network1_post[0]==shared[-1]):
raise ValueError
for i in range(len(network1_prev)-1):
network1_prev.append(nn.Linear(network1_prev[i],network1_prev[i+1]))
if activation is not None:
network1_prev.append(activation)
if weight_init is not None:
weight_init(network1_prev[-1].weight)
for i in range(len(network2_prev)-1):
network2_prev.append(nn.Linear(network2_prev[i],network2_prev[i+1]))
if activation is not None:
network2_prev.append(activation)
if weight_init is not None:
weight_init(network2_prev[-1].weight)
for i in range(len(shared)-1):
shared.append(nn.Linear(shared[i], shared[i+1]))
if activation is not None:
shared.append(activation)
if weight_init is not None:
weight_init(shared[-1].weight)
for i in range(len(network1_post)-1):
network1_post.append(nn.Linear(network1_post[i],network1_post[i+1]))
if activation is not None:
network1_post.append(activation)
if weight_init is not None:
weight_init(network1_post[-1].weight)
for i in range(len(network2_post)-1):
network2_post.append(nn.Linear(network2_post[i],network2_post[i+1]))
if activation is not None:
network2_post.append(activation)
if weight_init is not None:
weight_init(network2_post[-1].weight)
Expand Down

0 comments on commit 1365585

Please sign in to comment.