Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: VM/torch2 #233

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from rl_games.common import datasets

from torch import optim
import torch
from torch import nn
import numpy as np
import gym
import torch

from typing import Optional, Tuple


class A2CAgent(a2c_common.ContinuousA2CBase):
def __init__(self, base_name, params):
Expand All @@ -23,9 +23,10 @@ def __init__(self, base_name, params):
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)

self.states = None
self.init_rnn_from_model(self.model)
self.last_lr = float(self.last_lr)
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
self.epoch_num += 1
return self.epoch_num

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
Expand All @@ -74,6 +75,70 @@ def restore(self, fn):

def get_masked_action_values(self, obs, action_masks):
assert False

@torch.jit.script
def forward_for_gradients(smooth:bool, ppo:bool, has_value_loss:bool, bound_loss_type:str,
old_action_log_probs_batch, action_log_probs, advantage,
value_preds_batch, values, return_batch, mu, entropy,
curr_e_clip:float, clip_value:float, critic_coef:float, entropy_coef:float, bounds_loss_coef:float,
rnn_masks:Optional[torch.Tensor], sum_rnn_masks:int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if smooth:
a_loss = common_losses.smoothed_actor_loss(old_action_log_probs_batch, action_log_probs, advantage, ppo, curr_e_clip)
else:
a_loss = common_losses.actor_loss(old_action_log_probs_batch, action_log_probs, advantage, ppo, curr_e_clip)

if has_value_loss:
c_loss = common_losses.critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value)
else:
c_loss = torch.zeros(1)

if bound_loss_type == 'regularisation':
b_loss = (mu*mu).sum(dim=-1, keepdim=True)
elif bound_loss_type == 'bound':
soft_bound = 1.1
mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2
mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2
b_loss = (mu_loss_low + mu_loss_high).sum(dim=-1, keepdim=True)
else:
b_loss = torch.zeros(1)

losses = torch_ext.apply_masks_compilable([a_loss, c_loss, entropy, b_loss], rnn_masks, sum_rnn_masks)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * critic_coef - entropy * entropy_coef + b_loss * bounds_loss_coef

return loss, a_loss, c_loss, entropy, b_loss

@torch.jit.script
def forward_for_gradients(smooth:bool, ppo:bool, has_value_loss:bool, bound_loss_type:str,
old_action_log_probs_batch, action_log_probs, advantage,
value_preds_batch, values, return_batch, mu, entropy,
curr_e_clip:float, clip_value:float, critic_coef:float, entropy_coef:float, bounds_loss_coef:float,
rnn_masks:Optional[torch.Tensor], sum_rnn_masks:int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if smooth:
a_loss = common_losses.smoothed_actor_loss(old_action_log_probs_batch, action_log_probs, advantage, ppo, curr_e_clip)
else:
a_loss = common_losses.actor_loss(old_action_log_probs_batch, action_log_probs, advantage, ppo, curr_e_clip)

if has_value_loss:
c_loss = common_losses.critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value)
else:
c_loss = torch.zeros(1)
if bound_loss_type == 'regularisation':
b_loss = (mu*mu).sum(dim=-1, keepdim=True)
elif bound_loss_type == 'bound':
soft_bound = 1.1
mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0)**2
mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0)**2
b_loss = (mu_loss_low + mu_loss_high).sum(dim=-1, keepdim=True)
else:
b_loss = torch.zeros(1)
losses = torch_ext.apply_masks_compilable([a_loss, c_loss, entropy, b_loss], rnn_masks, sum_rnn_masks)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * critic_coef - entropy * entropy_coef + b_loss * bounds_loss_coef

return loss, a_loss, c_loss, entropy, b_loss

def calc_gradients(self, input_dict):
value_preds_batch = input_dict['old_values']
Expand Down Expand Up @@ -112,7 +177,13 @@ def calc_gradients(self, input_dict):
mu = res_dict['mus']
sigma = res_dict['sigmas']

a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)
loss, a_loss, c_loss, entropy, b_loss = \
A2CAgent.forward_for_gradients(self.actor_loss_func == common_losses.smoothed_actor_loss,
self.ppo, self.has_value_loss, self.bound_loss_type,
old_action_log_probs_batch, action_log_probs, advantage,
value_preds_batch, values, return_batch, mu, entropy.unsqueeze(1),
curr_e_clip, self.clip_value, self.critic_coef, self.entropy_coef, self.bounds_loss_coef,
None if rnn_masks is None else rnn_masks.unsqueeze(1), 0 if rnn_masks is None else rnn_masks.numel())

if self.has_value_loss:
c_loss = common_losses.critic_loss(self.model,value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
Expand Down
20 changes: 16 additions & 4 deletions rl_games/algos_torch/central_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import numpy as np
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
from rl_games.common import common_losses
from rl_games.common import datasets
from rl_games.common import schedulers
from rl_games.common import common_losses, datasets, schedulers
from typing import Dict, Optional


class CentralValueTrain(nn.Module):

def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions,
seq_len, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
nn.Module.__init__(self)
Expand Down Expand Up @@ -80,6 +80,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
total_agents = self.num_actors #* self.num_agents
num_seqs = self.horizon_length // self.seq_len
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)

self.mb_rnn_states = [ torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]

if self.multi_gpu:
Expand Down Expand Up @@ -225,9 +226,17 @@ def train_net(self):
self.frame += self.batch_size
if self.writter != None:
self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)

return avg_loss

@torch.jit.script
def forward_for_gradients(values, value_preds_batch, e_clip:float, returns_batch, clip_value:float, rnn_masks_batch:Optional[torch.Tensor], sum_rnn_masks:int):
loss = common_losses.critic_loss(value_preds_batch, values, e_clip, returns_batch, clip_value)
losses = torch_ext.apply_masks_compilable([loss], rnn_masks_batch, sum_rnn_masks)

return losses[0]

def calc_gradients(self, batch):
obs_batch = self._preproc_obs(batch['obs'])
value_preds_batch = batch['old_values']
Expand Down Expand Up @@ -255,6 +264,7 @@ def calc_gradients(self, batch):
else:
for param in self.model.parameters():
param.grad = None

loss.backward()

if self.multi_gpu:
Expand All @@ -263,9 +273,11 @@ def calc_gradients(self, batch):
for param in self.model.parameters():
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))

all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0

for param in self.model.parameters():
if param.grad is not None:
param.grad.data.copy_(
Expand Down
58 changes: 43 additions & 15 deletions rl_games/algos_torch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def build(self, config):
normalize_value = config.get('normalize_value', False)
normalize_input = config.get('normalize_input', False)
value_size = config.get('value_size', 1)

return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape,
normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size)


class BaseModelNetwork(nn.Module):
def __init__(self, obs_shape, normalize_value, normalize_input, value_size):
nn.Module.__init__(self)
Expand All @@ -40,12 +42,13 @@ def __init__(self, obs_shape, normalize_value, normalize_input, value_size):
self.value_size = value_size

if normalize_value:
self.value_mean_std = RunningMeanStd((self.value_size,)) # GeneralizedMovingStats((self.value_size,)) #
self.value_mean_std = RunningMeanStd((self.value_size,)) # GeneralizedMovingStats((self.value_size,)) #

if normalize_input:
if isinstance(obs_shape, dict):
self.running_mean_std = RunningMeanStdObs(obs_shape)
self.running_mean_std = torch.jit.script(RunningMeanStdObs(obs_shape))
else:
self.running_mean_std = RunningMeanStd(obs_shape)
self.running_mean_std = torch.jit.script(RunningMeanStd(obs_shape))

def norm_obs(self, observation):
with torch.no_grad():
Expand All @@ -55,6 +58,7 @@ def denorm_value(self, value):
with torch.no_grad():
return self.value_mean_std(value, denorm=True) if self.normalize_value else value


class ModelA2C(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
Expand All @@ -67,7 +71,7 @@ def __init__(self, a2c_network, **kwargs):

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

Expand Down Expand Up @@ -111,6 +115,7 @@ def forward(self, input_dict):
}
return result


class ModelA2CMultiDiscrete(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
Expand All @@ -123,7 +128,7 @@ def __init__(self, a2c_network, **kwargs):

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

Expand All @@ -141,6 +146,7 @@ def forward(self, input_dict):
prev_actions = input_dict.get('prev_actions', None)
input_dict['obs'] = self.norm_obs(input_dict['obs'])
logits, value, states = self.a2c_network(input_dict)

if is_train:
if action_masks is None:
categorical = [Categorical(logits=logit) for logit in logits]
Expand All @@ -163,8 +169,8 @@ def forward(self, input_dict):
if action_masks is None:
categorical = [Categorical(logits=logit) for logit in logits]
else:
categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]
categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]

selected_action = [c.sample().long() for c in categorical]
neglogp = [-c.log_prob(a.squeeze()) for c,a in zip(categorical, selected_action)]
selected_action = torch.stack(selected_action, dim=-1)
Expand All @@ -178,6 +184,7 @@ def forward(self, input_dict):
}
return result


class ModelA2CContinuous(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
Expand All @@ -190,7 +197,7 @@ def __init__(self, a2c_network, **kwargs):

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

Expand Down Expand Up @@ -262,21 +269,22 @@ def forward(self, input_dict):
mu, logstd, value, states = self.a2c_network(input_dict)
sigma = torch.exp(logstd)
distr = torch.distributions.Normal(mu, sigma, validate_args=False)

if is_train:
entropy = distr.entropy().sum(dim=-1)
prev_neglogp = self.neglogp(prev_actions, mu, sigma, logstd)
sigma, entropy, prev_neglogp = ModelA2CContinuousLogStd.compute_neglog_train(prev_actions, mu, logstd, prev_actions.size()[-1])
result = {
'prev_neglogp' : torch.squeeze(prev_neglogp),
'values' : value,
'entropy' : entropy,
'rnn_states' : states,
'mus' : mu,
'sigmas' : sigma
}
}

return result
else:
selected_action = distr.sample()
neglogp = self.neglogp(selected_action, mu, sigma, logstd)

selected_action, sigma, neglogp = ModelA2CContinuousLogStd.compute_neglog_infer(mu, logstd, torch.broadcast_shapes(mu.size(), logstd.size())[-1])
result = {
'neglogpacs' : torch.squeeze(neglogp),
'values' : self.denorm_value(value),
Expand All @@ -285,13 +293,34 @@ def forward(self, input_dict):
'mus' : mu,
'sigmas' : sigma
}

return result

def neglogp(self, x, mean, std, logstd):
return 0.5 * (((x - mean) / std)**2).sum(dim=-1) \
+ 0.5 * np.log(2.0 * np.pi) * x.size()[-1] \
+ logstd.sum(dim=-1)

@torch.jit.script
def compute_neglog_train(action, mu, logstd, action_cols:int):
sigma = torch.exp(logstd)
entropy = (0.5 + 0.5 * torch.log(2 * torch.pi) + torch.log(sigma)).sum(dim=-1)
neglogp = 0.5 * (((action - mu) / sigma)**2).sum(dim=-1) \
+ 0.5 * torch.log(2.0 * torch.pi) * action_cols \
+ logstd.sum(dim=-1)

return sigma, entropy, neglogp

@torch.jit.script
def compute_neglog_infer(mu, logstd, action_cols:int):
sigma = torch.exp(logstd)
selected_action = torch.normal(mu, sigma)
neglogp = 0.5 * (((selected_action - mu) / sigma)**2).sum(dim=-1) \
+ 0.5 * torch.log(2.0 * torch.pi) * action_cols \
+ logstd.sum(dim=-1)

return selected_action, sigma, neglogp


class ModelCentralValue(BaseModel):
def __init__(self, network):
Expand Down Expand Up @@ -330,13 +359,12 @@ def forward(self, input_dict):
return result



class ModelSACContinuous(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'sac')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, sac_network,**kwargs):
BaseModelNetwork.__init__(self,**kwargs)
Expand Down
Loading