Commit
shared mlp
AdityaKapoor74 committed Sep 3, 2020
1 parent 0927001 commit daa8b2a
Showing 1 changed file with 91 additions and 0 deletions.
genrl/utils/utils.py
@@ -62,6 +62,97 @@ def mlp(

    return nn.Sequential(*layers)

def shared_mlp(
    network1_prev,
    network2_prev,
    shared,
    network1_post,
    network2_post,
    weight_init,
    activation_func,
    sac,
):
"""
Generates an MLP model given sizes of each layer (Mostly used for SharedActorCritic)

:param network1_prev: Sizes of network1's initial layers
:param network2_prev: Sizes of network2's initial layers
:param shared: Sizes of shared layers
:param network1_post: Sizes of network1's latter layers
:param network2_post: Sizes of network2's latter layers
:param weight_init: type of weight initialization
:param activation_func: type of activation function
:param sac: True if Soft Actor Critic is being used, else False
:type network1_prev,network2_prev,shared,network1_post,network2_post: tuple or list
:type weight_init,activation_func: string
:type sac: bool
:returns: network1 and networ2(Neural Network with fully-connected linear layers and
activation layers)
"""
    # Build the layer containers separately so the size lists passed in are
    # not overwritten before they are iterated over below
    network1_prev_layers = nn.ModuleList()
    network2_prev_layers = nn.ModuleList()
    shared_layers = nn.ModuleList()
    network1_post_layers = nn.ModuleList()
    network2_post_layers = nn.ModuleList()
    # TODO: add more activation functions
    if activation_func == "relu":
        activation = F.relu
    elif activation_func == "tanh":
        activation = torch.tanh
    else:
        activation = None
    # NOTE: `activation` and `sac` are selected/accepted here but not yet
    # used further in this commit
    # TODO: add more weight initializations
    if weight_init == "xavier_uniform":
        weight_init = torch.nn.init.xavier_uniform_
    elif weight_init == "xavier_normal":
        weight_init = torch.nn.init.xavier_normal_
    else:
        weight_init = None
    # The two branches must line up with the shared trunk at both ends
    if len(shared) != 0 or len(network1_post) != 0 or len(network2_post) != 0:
        if not (
            network1_prev[-1] == network2_prev[-1]
            and network1_prev[-1] == shared[0]
            and network1_post[0] == network2_post[0]
            and network1_post[0] == shared[-1]
        ):
            raise ValueError("Layer sizes of the individual networks do not match the shared layers")
    for i in range(len(network1_prev) - 1):
        network1_prev_layers.append(nn.Linear(network1_prev[i], network1_prev[i + 1]))
        if weight_init is not None:
            weight_init(network1_prev_layers[-1].weight)
    for i in range(len(network2_prev) - 1):
        network2_prev_layers.append(nn.Linear(network2_prev[i], network2_prev[i + 1]))
        if weight_init is not None:
            weight_init(network2_prev_layers[-1].weight)
    for i in range(len(shared) - 1):
        shared_layers.append(nn.Linear(shared[i], shared[i + 1]))
        if weight_init is not None:
            weight_init(shared_layers[-1].weight)
    for i in range(len(network1_post) - 1):
        network1_post_layers.append(nn.Linear(network1_post[i], network1_post[i + 1]))
        if weight_init is not None:
            weight_init(network1_post_layers[-1].weight)
    for i in range(len(network2_post) - 1):
        network2_post_layers.append(nn.Linear(network2_post[i], network2_post[i + 1]))
        if weight_init is not None:
            weight_init(network2_post_layers[-1].weight)
    # Unpack the ModuleLists so nn.Sequential receives the layers themselves;
    # both networks hold references to the same shared modules, so the
    # trunk's weights are shared between them
    network1 = nn.Sequential(*network1_prev_layers, *shared_layers, *network1_post_layers)
    network2 = nn.Sequential(*network2_prev_layers, *shared_layers, *network2_post_layers)
    return network1, network2

def cnn(
    channels: Tuple = (4, 16, 32),
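For reference, a minimal sketch of how shared_mlp might be called to build an actor and critic with a shared trunk. The layer sizes, the import path, and the actor/critic framing below are illustrative assumptions, not taken from the repository:

    from genrl.utils.utils import shared_mlp  # assumed import path

    state_dim, action_dim = 8, 2  # hypothetical environment dimensions
    actor, critic = shared_mlp(
        network1_prev=(state_dim, 64),   # actor-specific input layers
        network2_prev=(state_dim, 64),   # critic-specific input layers
        shared=(64, 64),                 # trunk shared by both networks
        network1_post=(64, action_dim),  # actor head
        network2_post=(64, 1),           # critic head
        weight_init="xavier_uniform",
        activation_func="relu",
        sac=False,
    )
    # Both networks reference the same shared Linear modules, so updating
    # one network's trunk also updates the other's.
    assert actor[1] is critic[1]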
