
Commit

Clean up.
ViktorM committed Nov 4, 2024
1 parent 54d5eb6 commit 44ccb51
Showing 4 changed files with 45 additions and 45 deletions.
77 changes: 38 additions & 39 deletions rl_games/algos_torch/network_builder.py
@@ -5,9 +5,9 @@
import torch.nn as nn

from rl_games.algos_torch.d2rl import D2RLNet
-from rl_games.algos_torch.sac_helper import SquashedNormal
-from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
-from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
+from rl_games.algos_torch.sac_helper import SquashedNormal
+from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
+from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax


@@ -88,13 +88,13 @@ def _build_rnn(self, name, input, units, layers):
if name == 'gru':
return GRUWithDones(input_size=input, hidden_size=units, num_layers=layers)

-def _build_sequential_mlp(self,
-input_size,
-units,
+def _build_sequential_mlp(self,
+input_size,
+units,
activation,
dense_func,
-norm_only_first_layer=False,
-norm_func_name = None):
+norm_only_first_layer=False,
+norm_func_name=None):
print('build mlp:', input_size)
in_size = input_size
layers = []
@@ -115,13 +115,13 @@ def _build_sequential_mlp(self,

return nn.Sequential(*layers)

-def _build_mlp(self,
-input_size,
-units,
+def _build_mlp(self,
+input_size,
+units,
activation,
-dense_func,
+dense_func,
norm_only_first_layer=False,
-norm_func_name = None,
+norm_func_name=None,
d2rl=False):
if d2rl:
act_layers = [self.activations_factory.create(activation) for i in range(len(units))]
@@ -213,21 +213,21 @@ def __init__(self, params, **kwargs):
self.critic_cnn = nn.Sequential()
self.actor_mlp = nn.Sequential()
self.critic_mlp = nn.Sequential()

if self.has_cnn:
if self.permute_input:
input_shape = torch_ext.shape_whc_to_cwh(input_shape)
cnn_args = {
'ctype' : self.cnn['type'],
'input_shape' : input_shape,
-'convs' :self.cnn['convs'],
+'convs' : self.cnn['convs'],
'activation' : self.cnn['activation'],
'norm_func_name' : self.normalization,
}
self.actor_cnn = self._build_conv(**cnn_args)

if self.separate:
-self.critic_cnn = self._build_conv( **cnn_args)
+self.critic_cnn = self._build_conv(**cnn_args)

cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn)

@@ -318,7 +318,7 @@ def __init__(self, params, **kwargs):
if self.fixed_sigma:
sigma_init(self.sigma)
else:
-sigma_init(self.sigma.weight)
+sigma_init(self.sigma.weight)

def forward(self, obs_dict):
obs = obs_dict['obs']
@@ -339,7 +339,7 @@ def forward(self, obs_dict):
a_out = a_out.contiguous().view(a_out.size(0), -1)

c_out = self.critic_cnn(c_out)
-c_out = c_out.contiguous().view(c_out.size(0), -1)
+c_out = c_out.contiguous().view(c_out.size(0), -1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)
@@ -359,11 +359,11 @@
a_out = a_out.reshape(num_seqs, seq_length, -1)
c_out = c_out.reshape(num_seqs, seq_length, -1)

-a_out = a_out.transpose(0,1)
-c_out = c_out.transpose(0,1)
+a_out = a_out.transpose(0, 1)
+c_out = c_out.transpose(0, 1)
if dones is not None:
dones = dones.reshape(num_seqs, seq_length, -1)
-dones = dones.transpose(0,1)
+dones = dones.transpose(0, 1)

if len(states) == 2:
a_states = states[0]
@@ -374,8 +374,8 @@
a_out, a_states = self.a_rnn(a_out, a_states, dones, bptt_len)
c_out, c_states = self.c_rnn(c_out, c_states, dones, bptt_len)

-a_out = a_out.transpose(0,1)
-c_out = c_out.transpose(0,1)
+a_out = a_out.transpose(0, 1)
+c_out = c_out.transpose(0, 1)
a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)

@@ -398,7 +398,7 @@ def forward(self, obs_dict):
else:
a_out = self.actor_mlp(a_out)
c_out = self.critic_mlp(c_out)

value = self.value_act(self.value(c_out))

if self.is_discrete:
@@ -474,7 +474,7 @@ def forward(self, obs_dict):
else:
sigma = self.sigma_act(self.sigma(out))
return mu, mu*0 + sigma, value, states

def is_separate_critic(self):
return self.separate

@@ -503,7 +503,7 @@ def get_default_rnn_state(self):
return (torch.zeros((num_layers, self.num_seqs, rnn_units)),
torch.zeros((num_layers, self.num_seqs, rnn_units)))
else:
-return (torch.zeros((num_layers, self.num_seqs, rnn_units)),)
+return (torch.zeros((num_layers, self.num_seqs, rnn_units)),)

def load(self, params):
self.separate = params.get('separate', False)
@@ -655,10 +655,10 @@ def __init__(self, params, **kwargs):

if self.has_rnn:
if not self.is_rnn_before_mlp:
-rnn_in_size = out_size
+rnn_in_size = out_size
out_size = self.rnn_units
else:
-rnn_in_size = mlp_input_size
+rnn_in_size = mlp_input_size
mlp_input_size = self.rnn_units

if self.require_rewards:
@@ -667,12 +667,12 @@
rnn_in_size += actions_num

self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers)
-#self.layer_norm = torch.nn.LayerNorm(self.rnn_units)
+# self.layer_norm = torch.nn.LayerNorm(self.rnn_units)

mlp_args = {
'input_size' : mlp_input_size,
-'units' :self.units,
-'activation' : self.activation,
+'units' : self.units,
+'activation' : self.activation,
'norm_func_name' : self.normalization,
'dense_func' : torch.nn.Linear
}
@@ -687,9 +687,9 @@
self.logits = torch.nn.Linear(out_size, actions_num)
if self.is_continuous:
self.mu = torch.nn.Linear(out_size, actions_num)
-self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
+self.mu_act = self.activations_factory.create(self.space_config['mu_activation'])
mu_init = self.init_factory.create(**self.space_config['mu_init'])
-self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
+self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation'])
sigma_init = self.init_factory.create(**self.space_config['sigma_init'])

if self.fixed_sigma:
@@ -716,7 +716,7 @@
else:
sigma_init(self.sigma.weight)

-mlp_init(self.value.weight)
+mlp_init(self.value.weight)

def forward(self, obs_dict):
if self.require_rewards or self.require_last_actions:
@@ -827,7 +827,7 @@ def load(self, params):

def _build_impala(self, input_shape, depths):
in_channels = input_shape[0]
-layers = nn.ModuleList()
+layers = nn.ModuleList()
for d in depths:
layers.append(ImpalaSequential(in_channels, d))
in_channels = d
@@ -845,7 +845,7 @@ def get_default_rnn_state(self):
return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)),
torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
else:
-return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))
+return (torch.zeros((num_layers, self.num_seqs, self.rnn_units)))

def build(self, name, **kwargs):
net = A2CResnetBuilder.Network(self.params, **kwargs)
@@ -976,7 +976,7 @@ def forward(self, obs_dict):
obs = obs_dict['obs']
mu, sigma = self.actor(obs)
return mu, sigma

def is_separate_critic(self):
return self.separate

@@ -997,12 +997,11 @@ def load(self, params):

if self.has_space:
self.is_discrete = 'discrete' in params['space']
-self.is_continuous = 'continuous'in params['space']
+self.is_continuous = 'continuous' in params['space']
if self.is_continuous:
self.space_config = params['space']['continuous']
elif self.is_discrete:
self.space_config = params['space']['discrete']
else:
self.is_discrete = False
self.is_continuous = False
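
Note on the reformatted `_build_sequential_mlp` signature above: the builder stacks a `Linear` layer, an optional normalization (after the first layer only when `norm_only_first_layer` is set), and the chosen activation for each entry in `units`. A minimal standalone sketch of that pattern, assuming `layer_norm`/`batch_norm` as the supported `norm_func_name` values (illustrative, not the repository's exact code):

```python
import torch
import torch.nn as nn

def build_sequential_mlp(input_size,
                         units,
                         activation=nn.ELU,
                         dense_func=nn.Linear,
                         norm_only_first_layer=False,
                         norm_func_name=None):
    """Sketch of the Linear -> norm -> activation stacking pattern."""
    layers = []
    in_size = input_size
    need_norm = norm_func_name is not None
    for unit in units:
        layers.append(dense_func(in_size, unit))
        if need_norm:
            if norm_func_name == 'layer_norm':
                layers.append(nn.LayerNorm(unit))
            elif norm_func_name == 'batch_norm':
                layers.append(nn.BatchNorm1d(unit))
            if norm_only_first_layer:
                need_norm = False  # normalize only after the first layer
        layers.append(activation())
        in_size = unit
    return nn.Sequential(*layers)

# Example: a three-layer ELU MLP for a 24-dimensional observation.
mlp = build_sequential_mlp(24, [512, 256, 128],
                           norm_func_name='layer_norm',
                           norm_only_first_layer=True)
print(mlp(torch.randn(4, 24)).shape)  # torch.Size([4, 128])
```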

4 changes: 2 additions & 2 deletions rl_games/common/wrappers.py
@@ -635,15 +635,15 @@ def observation(self, observation):

class OldGymWrapper(gym.Env):
def __init__(self, env):
+import gymnasium

self.env = env

# Convert Gymnasium spaces to Gym spaces
self.observation_space = self.convert_space(env.observation_space)
self.action_space = self.convert_space(env.action_space)

def convert_space(self, space):
-import gymnasium

"""Recursively convert Gymnasium spaces to Gym spaces."""
if isinstance(space, gymnasium.spaces.Box):
return gym.spaces.Box(
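Note: the `convert_space` helper above is cut off by the collapsed portion of the diff. A hedged sketch of how such a recursive Gymnasium-to-Gym conversion typically looks (the `Discrete`, `Dict`, and `Tuple` branches here are illustrative assumptions, not the repository's verified code):

```python
import gym
import gymnasium

def convert_space(space):
    """Recursively map Gymnasium spaces onto their classic Gym equivalents (illustrative sketch)."""
    if isinstance(space, gymnasium.spaces.Box):
        return gym.spaces.Box(low=space.low, high=space.high,
                              shape=space.shape, dtype=space.dtype)
    if isinstance(space, gymnasium.spaces.Discrete):
        return gym.spaces.Discrete(space.n)
    if isinstance(space, gymnasium.spaces.Dict):
        return gym.spaces.Dict({key: convert_space(sub) for key, sub in space.spaces.items()})
    if isinstance(space, gymnasium.spaces.Tuple):
        return gym.spaces.Tuple(tuple(convert_space(sub) for sub in space.spaces))
    raise NotImplementedError(f"Unsupported space type: {type(space)}")
```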
2 changes: 1 addition & 1 deletion rl_games/configs/myosuite/ppo_myo_hand_pose.yaml
@@ -59,7 +59,7 @@ params:
e_clip: 0.2
clip_value: False
num_actors: 32
-horizon_length: 128
+horizon_length: 256
minibatch_size: 2048
mini_epochs: 5
critic_coef: 2
7 changes: 4 additions & 3 deletions rl_games/configs/myosuite/ppo_myo_hand_reach.yaml
@@ -20,12 +20,13 @@ params:
val: 0
fixed_sigma: True
mlp:
-units: [256, 128, 64]
+units: [512, 256, 128]
d2rl: False
activation: elu
initializer:
name: default
scale: 2

config:
env_name: myo_gym
name: MyoHandReachRandom
@@ -50,12 +50,12 @@
e_clip: 0.2
clip_value: False
num_actors: 32
-horizon_length: 128
+horizon_length: 256
minibatch_size: 2048
mini_epochs: 5
critic_coef: 2
bounds_loss_coef: 0.001
-max_epochs: 5000
+max_epochs: 10000
use_diagnostics: True
weight_decay: 0.0
use_smooth_clamp: True
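Note on the tuned values above: assuming rl_games' usual requirement that num_actors * horizon_length divide evenly by minibatch_size, the doubled horizon still checks out:

```python
num_actors, horizon_length, minibatch_size = 32, 256, 2048
batch_size = num_actors * horizon_length     # 8192 transitions collected per rollout
assert batch_size % minibatch_size == 0      # splits evenly across minibatches
print(batch_size // minibatch_size)          # 4 minibatches per mini-epoch
```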
