Fixed fp16 central value #60

Open · wants to merge 2 commits into base: master
Changes from all commits
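Context for the diffs below: the two commits wire the central value network into PyTorch's standard automatic mixed precision (AMP) recipe, with forward passes under torch.cuda.amp.autocast and backward passes routed through a torch.cuda.amp.GradScaler. A minimal, self-contained sketch of that recipe (net, opt, and batch are placeholders for illustration, not rl_games API):

import torch
import torch.nn as nn

# Sketch of the torch.cuda.amp recipe applied in this PR; assumes a CUDA
# device. `net` and `opt` stand in for the central value model and optimizer.
net = nn.Linear(8, 1).cuda()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)

batch = torch.randn(4, 8, device='cuda')
with torch.cuda.amp.autocast(enabled=True):
    loss = net(batch).pow(2).mean()  # forward + loss run in fp16 where safe

scaler.scale(loss).backward()        # scale the loss to avoid fp16 underflow
scaler.step(opt)                     # unscales grads, skips the step on inf/nan
scaler.update()                      # adapts the loss scale for the next step
opt.zero_grad(set_to_none=True)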
27 changes: 15 additions & 12 deletions rl_games/algos_torch/a2c_continuous.py
@@ -51,6 +51,7 @@ def __init__(self, base_name, config):
            }
            self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
            self.use_experimental_cv = self.config.get('use_experimental_cv', True)
+
        self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
        self.algo_observer.after_init(self)
@@ -71,6 +72,7 @@ def get_masked_action_values(self, obs, action_masks):

    def calc_gradients(self, input_dict, opt_step):
        self.set_train()
+
        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
@@ -89,16 +91,16 @@ def calc_gradients(self, input_dict, opt_step):
        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
-           'obs' : obs_batch,
+           'obs': obs_batch,
        }

+       rnn_masks = None
+       if self.is_rnn:
+           rnn_masks = input_dict['rnn_masks']
+           batch_dict['rnn_states'] = input_dict['rnn_states']
+           batch_dict['seq_length'] = self.seq_len
+
        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
-           rnn_masks = None
-           if self.is_rnn:
-               rnn_masks = input_dict['rnn_masks']
-               batch_dict['rnn_states'] = input_dict['rnn_states']
-               batch_dict['seq_length'] = self.seq_len
-
            res_dict = self.model(batch_dict)
            action_log_probs = res_dict['prev_neglogp']
            values = res_dict['value']
@@ -133,11 +135,12 @@ def calc_gradients(self, input_dict, opt_step):
            self.scaler.update()

        with torch.no_grad():
-           reduce_kl = not self.is_rnn
-           kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
-           if self.is_rnn:
-               kl_dist = (kl_dist * rnn_masks).sum() / sum_mask
-           kl_dist = kl_dist.item()
+           with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+               reduce_kl = not self.is_rnn
+               kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
+               if self.is_rnn:
+                   kl_dist = (kl_dist * rnn_masks).sum() / sum_mask
+               kl_dist = kl_dist.item()

        self.train_result = (a_loss.item(), c_loss.item(), entropy.item(), \
                    kl_dist, self.last_lr, lr_mul, \
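A plausible reading of the KL change above: once the policy forward pass runs under autocast, its outputs (mu, sigma, values) come back as fp16 tensors, while stored rollout tensors such as old_mu_batch stay fp32, so running the KL statistics under autocast as well lets PyTorch insert the needed casts. A quick illustration of that dtype behavior (assumes a CUDA device; not rl_games code):

import torch

# Under autocast, eligible ops such as matmul (and hence a model forward)
# return fp16 tensors even though the inputs are fp32.
x = torch.randn(8, 8, device='cuda')
with torch.cuda.amp.autocast():
    y = x @ x
print(y.dtype)  # torch.float16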
101 changes: 62 additions & 39 deletions rl_games/algos_torch/central_value.py
@@ -31,49 +31,58 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, num_steps, n
        self.clip_value = config['clip_value']
        self.normalize_input = config['normalize_input']
        self.normalize_value = config.get('normalize_value', False)
-       self.running_mean_std = None
-       if self.normalize_input:
-           self.running_mean_std = RunningMeanStd(state_shape)

        self.writter = writter
        self.use_joint_obs_actions = config.get('use_joint_obs_actions', False)
        self.optimizer = torch.optim.Adam(self.model.parameters(), float(self.lr), eps=1e-07)
        self.frame = 0
+       self.running_mean_std = None
+
        self.grad_norm = config.get('grad_norm', 1)
        self.truncate_grads = config.get('truncate_grads', False)
        self.e_clip = config.get('e_clip', 0.2)
+       if self.normalize_input:
+           self.running_mean_std = RunningMeanStd(state_shape)
+
+       # todo - from the main config!
+       self.mixed_precision = self.config.get('mixed_precision', True)
+       self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)
+
        self.is_rnn = self.model.is_rnn()
        self.rnn_states = None
        self.batch_size = self.num_steps * self.num_actors

        if self.is_rnn:
-           self.rnn_states = self.model.get_default_rnn_state()
-           self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
-           num_seqs = self.num_steps * self.num_actors // self.seq_len
-           assert((self.num_steps * self.num_actors // self.num_minibatches) % self.seq_len == 0)
-           self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]
+           with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+               self.rnn_states = self.model.get_default_rnn_state()
+               self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
+               num_seqs = self.num_steps * self.num_actors // self.seq_len
+               assert((self.num_steps * self.num_actors // self.num_minibatches) % self.seq_len == 0)
+               self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]

        self.dataset = datasets.PPODataset(self.batch_size, self.mini_batch, True, self.is_rnn, self.ppo_device, self.seq_len)

-   def get_stats_weights(self):
+   def get_stats_weights(self):
        if self.normalize_input:
            return self.running_mean_std.state_dict()
        else:
            return None

    def set_stats_weights(self, weights):
        self.running_mean_std.load_state_dict(weights)

    def update_dataset(self, batch_dict):
        value_preds = batch_dict['old_values']
        returns = batch_dict['returns']
        actions = batch_dict['actions']
        rnn_masks = batch_dict['rnn_masks']
+
        if self.num_agents > 1:
            res = self.update_multiagent_tensors(value_preds, returns, actions, rnn_masks)
            batch_dict['old_values'] = res[0]
            batch_dict['returns'] = res[1]
            batch_dict['actions'] = res[2]
+
        if self.is_rnn:
            batch_dict['rnn_states'] = self.mb_rnn_states
        if self.num_agents > 1:
@@ -90,9 +99,11 @@ def _preproc_obs(self, obs_batch):
            obs_batch = obs_batch.permute((0, 3, 1, 2))
        if self.normalize_input:
            obs_batch = self.running_mean_std(obs_batch)
+
        return obs_batch

    def pre_step_rnn(self, rnn_indices, state_indices):
+       #with torch.cuda.amp.autocast(enabled=self.mixed_precision):
        if self.num_agents > 1:
            rnn_indices = rnn_indices[::self.num_agents]
            shifts = rnn_indices % (self.num_steps // self.seq_len)
@@ -105,10 +116,12 @@ def pre_step_rnn(self, rnn_indices, state_indices):
    def post_step_rnn(self, all_done_indices):
        all_done_indices = all_done_indices[::self.num_agents] // self.num_agents
        for s in self.rnn_states:
-           s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0
+           s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0

    def forward(self, input_dict):
+       #with torch.cuda.amp.autocast(enabled=self.mixed_precision):
        value, rnn_states = self.model(input_dict)
+
        return value, rnn_states

@@ -118,7 +131,7 @@ def get_value(self, input_dict):
        actions = input_dict.get('actions', None)

        obs_batch = self._preproc_obs(obs_batch)
-       value, self.rnn_states = self.forward({'obs' : obs_batch, 'actions': actions,
+       value, self.rnn_states = self.forward({'obs': obs_batch, 'actions': actions,
                                               'rnn_states': self.rnn_states})
        if self.num_agents > 1:
            value = value.repeat(1, self.num_agents)
@@ -135,19 +148,20 @@ def train_critic(self, input_dict, opt_step = True):
    def update_multiagent_tensors(self, value_preds, returns, actions, rnn_masks):
        batch_size = self.batch_size
        ma_batch_size = self.num_actors * self.num_agents * self.num_steps
-       value_preds = value_preds.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0,1)
-       returns = returns.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0,1)
+       value_preds = value_preds.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0, 1)
+       returns = returns.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0, 1)
        value_preds = value_preds.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
        returns = returns.contiguous().view(ma_batch_size, self.value_size)[:batch_size]

        if self.use_joint_obs_actions:
-           assert(len(actions.size()) == 2, 'use_joint_obs_actions not yet supported in continuous environment for central value')
-           actions = actions.view(self.num_actors, self.num_agents, self.num_steps).transpose(0,1)
+           assert len(actions.size()) == 2, 'use_joint_obs_actions not yet supported in continuous environment for central value'
+           actions = actions.view(self.num_actors, self.num_agents, self.num_steps).transpose(0, 1)
            actions = actions.contiguous().view(batch_size, self.num_agents)

        if self.is_rnn:
-           rnn_masks = rnn_masks.view(self.num_actors, self.num_agents, self.num_steps).transpose(0,1)
-           rnn_masks = rnn_masks.flatten(0)[:batch_size]
+           rnn_masks = rnn_masks.view(self.num_actors, self.num_agents, self.num_steps).transpose(0, 1)
+           rnn_masks = rnn_masks.flatten(0)[:batch_size]

        return value_preds, returns, actions, rnn_masks

@@ -157,33 +171,42 @@ def train_net(self):
        for idx in range(len(self.dataset)):
            loss += self.train_critic(self.dataset[idx])
        avg_loss = loss / (self.mini_epoch * self.num_minibatches)
+
        self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)
        self.frame += self.batch_size
+
        return avg_loss

    def calc_gradients(self, batch, opt_step):
-       obs_batch = self._preproc_obs(batch['obs'])
-       value_preds_batch = batch['old_values']
-       returns_batch = batch['returns']
-       actions_batch = batch['actions']
-       rnn_masks_batch = batch.get('rnn_masks')
-
-       batch_dict = {'obs' : obs_batch,
-                     'actions' : actions_batch,
-                     'seq_length' : self.seq_len }
-       if self.is_rnn:
-           batch_dict['rnn_states'] = batch['rnn_states']
-
-       values, _ = self.forward(batch_dict)
-       loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
-       losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
-       loss = losses[0]
+       with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+           obs_batch = self._preproc_obs(batch['obs'])
+           value_preds_batch = batch['old_values']
+           returns_batch = batch['returns']
+           actions_batch = batch['actions']
+           rnn_masks_batch = batch.get('rnn_masks')
+
+           batch_dict = {'obs': obs_batch,
+                         'actions': actions_batch,
+                         'seq_length': self.seq_len }
+
+           if self.is_rnn:
+               batch_dict['rnn_states'] = batch['rnn_states']
+
+           values, _ = self.forward(batch_dict)
+           loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
+           losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
+           loss = losses[0]

        for param in self.model.parameters():
            param.grad = None
-       loss.backward()
-       if self.truncate_grads:
+
+       self.scaler.scale(loss).backward()
+       if self.config['truncate_grads']:
+           self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
+
        if opt_step:
-           self.optimizer.step()
+           self.scaler.step(self.optimizer)
+           self.scaler.update()

        return loss
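The backward path in calc_gradients above follows the standard GradScaler clipping recipe: gradients produced by scaler.scale(loss).backward() carry the loss-scale factor, so they must be unscaled with scaler.unscale_(optimizer) before nn.utils.clip_grad_norm_, otherwise the clip threshold would be compared against scaled values; scaler.step then applies the already-unscaled gradients. A self-contained sketch (assumes a CUDA device; net and opt are placeholders, not rl_games API):

import torch
import torch.nn as nn

# Unscale-then-clip with GradScaler, mirroring the truncate_grads branch above.
net = nn.Linear(16, 1).cuda()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(32, 16, device='cuda')
with torch.cuda.amp.autocast():
    loss = net(x).pow(2).mean()

scaler.scale(loss).backward()   # grads are multiplied by the loss scale here
scaler.unscale_(opt)            # restore true gradient magnitudes
nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
scaler.step(opt)                # detects the prior unscale_, applies the step
scaler.update()
opt.zero_grad(set_to_none=True)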