Adding weight initialization options #223

Merged: 6 commits, Nov 3, 2018

slm_lab/agent/net/convnet.py (4 changes: 3 additions & 1 deletion)

@@ -91,6 +91,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         super(ConvNet, self).__init__(net_spec, in_dim, out_dim)
         # set default
         util.set_attr(self, dict(
+            init_fxn='xavier_uniform_',
             batch_norm=True,
             clip_grad=False,
             clip_grad_val=1.0,
@@ -106,6 +107,7 @@ def __init__(self, net_spec, in_dim, out_dim):
             'hid_layers',
             'hid_layers_activation',
             'batch_norm',
+            'init_fxn',
             'clip_grad',
             'clip_grad_val',
             'loss_spec',
@@ -130,7 +132,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         tail_in_dim = self.dense_hid_layers[-1] if len(self.dense_hid_layers) > 0 else self.conv_out_dim
         self.model_tails = nn.ModuleList([nn.Linear(tail_in_dim, out_d) for out_d in self.out_dim])

-        net_util.init_layers(self.modules())
+        net_util.init_layers(self, self.init_fxn)
         for module in self.modules():
             module.to(self.device)
         self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
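
For reference, a minimal sketch of how the new option could appear in a net spec. Only init_fxn is introduced by this PR; the other keys mirror the ConvNet defaults above, while the 'type' key and the concrete values are illustrative assumptions rather than a tested spec:

# hypothetical net spec fragment (illustrative, not part of this PR)
net_spec = {
    'type': 'ConvNet',  # assumed spec key for the net class
    'hid_layers_activation': 'relu',
    'init_fxn': 'xavier_uniform_',  # or e.g. 'kaiming_uniform_', 'orthogonal_'
    'batch_norm': True,
    'clip_grad': False,
    'loss_spec': {'name': 'MSELoss'},
}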

slm_lab/agent/net/mlp.py (4 changes: 3 additions & 1 deletion)

@@ -64,6 +64,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         super(MLPNet, self).__init__(net_spec, in_dim, out_dim)
         # set default
         util.set_attr(self, dict(
+            init_fxn='xavier_uniform_',
             clip_grad=False,
             clip_grad_val=1.0,
             loss_spec={'name': 'MSELoss'},
@@ -78,6 +79,7 @@ def __init__(self, net_spec, in_dim, out_dim):
             'separate',
             'hid_layers',
             'hid_layers_activation',
+            'init_fxn',
             'clip_grad',
             'clip_grad_val',
             'loss_spec',
@@ -100,7 +102,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         else:  # if more than 1 output, add last layer as tails separate from main model
             self.model_tails = nn.ModuleList([nn.Linear(dims[-1], out_d) for out_d in self.out_dim])

-        net_util.init_layers(self.modules())
+        net_util.init_layers(self, self.init_fxn)
         for module in self.modules():
             module.to(self.device)
         self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
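
The call-site change here (and in convnet.py above) hands init_layers the whole network plus the spec'd initializer name; net_util.init_layers below then walks the network with nn.Module.apply. As a quick standalone reminder of what apply does, using a toy model unrelated to SLM-Lab:

import torch.nn as nn

def report(module):
    # apply() visits every submodule depth-first, then the container itself
    print(module.__class__.__name__)

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net.apply(report)  # prints Linear, ReLU, Linear, Sequential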

slm_lab/agent/net/net_util.py (56 changes: 33 additions & 23 deletions)

@@ -1,3 +1,5 @@
+from functools import partial
+
 from slm_lab import ROOT_DIR
 from slm_lab.agent.algorithm import policy_util
 from slm_lab.lib import logger, util
@@ -83,34 +85,42 @@ def get_out_dim(body, add_critic=False):
     return out_dim


-def init_gru_layer(layer):
-    '''Initializes a GRU layer with xavier_uniform initialization and 0 biases'''
-    for layer_p in layer._all_weights:
-        for p in layer_p:
-            if 'weight' in p:
-                torch.nn.init.xavier_uniform_(layer.__getattr__(p))
-            elif 'bias' in p:
-                torch.nn.init.constant_(layer.__getattr__(p), 0.0)
+def init_layers(net, init_fxn):
+    if init_fxn == 'xavier_uniform_':
+        try:
+            gain = torch.nn.init.calculate_gain(net.hid_layers_activation)
+        except ValueError:
+            gain = 1
+        init_fxn = partial(torch.nn.init.xavier_uniform_, gain=gain)
+    elif 'kaiming' in init_fxn:
+        assert net.hid_layers_activation in ['relu', 'leaky_relu'], f'Kaiming initialization not supported for {net.hid_layers_activation}'
+        init_fxn = torch.nn.init.__dict__[init_fxn]
+        init_fxn = partial(init_fxn, nonlinearity=net.hid_layers_activation)
+    else:
+        init_fxn = torch.nn.init.__dict__[init_fxn]
+    net.apply(partial(init_parameters, init_fxn=init_fxn))


-def init_layers(layers):
+def init_parameters(module, init_fxn):
     '''
-    Initializes all layers of type 'Linear', 'Conv', or 'GRU', using xavier uniform initialization for the weights, 0.01 for the biases, and 0.0 for the GRU biases.
-    Initializes all layers of type 'BatchNorm' using uniform initialization for the weights and the same as above for the biases.
+    Initializes the module's weights using init_fxn, an initialization function from torch.nn.init
+    Initializes the module's biases to a constant: 0.0 for GRU biases, 0.01 otherwise
+    Only Linear, Conv, GRU and BatchNorm modules are initialized; other module types are left untouched
     '''
     bias_init = 0.01
-    for layer in layers:
-        classname = layer.__class__.__name__
-        if 'BatchNorm' in classname:
-            torch.nn.init.uniform_(layer.weight.data)
-            torch.nn.init.constant_(layer.bias.data, bias_init)
-        elif 'GRU' in classname:
-            init_gru_layer(layer)
-        elif 'Linear' in classname:
-            torch.nn.init.xavier_uniform_(layer.weight.data)
-            torch.nn.init.constant_(layer.bias.data, bias_init)
-        else:
-            pass
+    classname = module.__class__.__name__
+    if 'BatchNorm' in classname:
+        init_fxn(module.weight)
+        torch.nn.init.constant_(module.bias, bias_init)
+    elif 'GRU' in classname:
+        for name, param in module.named_parameters():
+            if 'weight' in name:
+                init_fxn(param)
+            elif 'bias' in name:
+                torch.nn.init.constant_(param, 0.0)
+    elif 'Linear' in classname or ('Conv' in classname and 'Net' not in classname):
+        init_fxn(module.weight)
+        torch.nn.init.constant_(module.bias, bias_init)


 # lr decay methods
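
To illustrate the default path, here is a self-contained sketch of what init_layers builds for init_fxn='xavier_uniform_' with a relu activation, using a trimmed copy of init_parameters and a toy model; everything beyond the PyTorch calls shown in the diff is an illustrative assumption:

from functools import partial

import torch
import torch.nn as nn

hid_layers_activation = 'relu'  # assumed value read from the net spec

# build the initializer the way init_layers does for the default option
try:
    gain = torch.nn.init.calculate_gain(hid_layers_activation)  # sqrt(2) for relu
except ValueError:
    gain = 1  # fallback for activations calculate_gain does not recognize
xavier = partial(torch.nn.init.xavier_uniform_, gain=gain)

def init_parameters(module, init_fxn):
    # trimmed version of the function above: weights via init_fxn, biases constant
    if 'Linear' in module.__class__.__name__:
        init_fxn(module.weight)
        torch.nn.init.constant_(module.bias, 0.01)

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net.apply(partial(init_parameters, init_fxn=xavier))
print(net[0].bias[:3])  # all biases are now 0.01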

slm_lab/agent/net/recurrent.py (4 changes: 3 additions & 1 deletion)

@@ -75,6 +75,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         super(RecurrentNet, self).__init__(net_spec, in_dim, out_dim)
         # set default
         util.set_attr(self, dict(
+            init_fxn='xavier_uniform_',
             rnn_num_layers=1,
             clip_grad=False,
             clip_grad_val=1.0,
@@ -92,6 +93,7 @@ def __init__(self, net_spec, in_dim, out_dim):
             'rnn_hidden_size',
             'rnn_num_layers',
             'seq_len',
+            'init_fxn',
             'clip_grad',
             'clip_grad_val',
             'loss_spec',
@@ -120,7 +122,7 @@ def __init__(self, net_spec, in_dim, out_dim):
         # tails
         self.model_tails = nn.ModuleList([nn.Linear(self.rnn_hidden_size, out_d) for out_d in self.out_dim])

-        net_util.init_layers(self.modules())
+        net_util.init_layers(self, self.init_fxn)
         for module in self.modules():
             module.to(self.device)
         self.loss_fn = net_util.get_loss_fn(self, self.loss_spec)
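
For recurrent nets, the GRU branch of init_parameters applies the chosen initializer to each weight matrix and zeroes the biases via named_parameters. A standalone sketch of that branch combined with the 'kaiming' name lookup, again using plain PyTorch with illustrative sizes:

from functools import partial

import torch
import torch.nn as nn

# pick the initializer by name, as init_layers does for 'kaiming*' options
init_fxn = partial(torch.nn.init.__dict__['kaiming_uniform_'], nonlinearity='relu')

gru = nn.GRU(input_size=4, hidden_size=8, num_layers=1, batch_first=True)
for name, param in gru.named_parameters():
    if 'weight' in name:  # weight_ih_l0, weight_hh_l0
        init_fxn(param)
    elif 'bias' in name:  # bias_ih_l0, bias_hh_l0
        torch.nn.init.constant_(param, 0.0)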