WIP: Torch upgrade. Aux loss. Vision backbone networks. #304

Open: wants to merge 35 commits into base: master

Commits (35)
c41212b
Updated to torch 2. Compile without max-autotune is working.
ViktorM May 21, 2024
6a91bd3
Fixed minibatch_per_env bug.
ViktorM Jun 12, 2024
46c8c48
Merge from master.
ViktorM Jun 27, 2024
ae043b9
Merge branch 'master' into VM/torch_upgrade
ViktorM Jul 12, 2024
1e58c1e
WIP: adding proprio observations to resnet network.
ViktorM Aug 16, 2024
80f43f7
Added pretrained resnet e2e network. Fixes.
ViktorM Aug 16, 2024
7ea73fa
Clean up.
ViktorM Aug 16, 2024
2c285bd
Temporarily moved vision_actor_critic back to network_builder.
ViktorM Aug 16, 2024
ec4e4f1
E2e training is working with Impala network and proprioception.
ViktorM Aug 17, 2024
04d653a
Fixed input dictionary observation shapes for resnet network. Fixed amp.
ViktorM Aug 17, 2024
637d2bc
Added support for more visual backbones.
ViktorM Aug 17, 2024
0dc5b43
WIP: network with different vision backbones.
ViktorM Aug 18, 2024
eeae30e
Vision backbone improvements.
ViktorM Aug 19, 2024
006fe2c
Resnet18 works!
ViktorM Aug 19, 2024
61d1bc4
Added convnext and vit backbones support. Added preprocessing.
ViktorM Aug 20, 2024
c1eeeba
Fixed observation dict.
ViktorM Aug 21, 2024
3531fb8
Pacman training.
ViktorM Aug 22, 2024
d073f4d
Refactored vision backbone networks. Maniskill first pass.
ViktorM Aug 31, 2024
ead882c
Fixed Maniskill state-based training.
ViktorM Aug 31, 2024
8be57aa
Fixed Maniskill resets. Push and pick configs.
ViktorM Sep 1, 2024
d21fd8f
Maniskill RGB and Depth only observations. Added resnet and impala cu…
ViktorM Sep 1, 2024
dc4e279
Fix.
ViktorM Sep 1, 2024
a78caac
WIP: norm layer with Impala for pick cube vision training.
ViktorM Sep 4, 2024
2fe6c5f
added test aux_loss
DenSumy Sep 8, 2024
8b274d1
Layer norm for vision model. Better maniskill training configs.
ViktorM Sep 8, 2024
00cbd3d
Merge with aux_loss branch.
ViktorM Sep 8, 2024
4d29423
Aux loss.
ViktorM Sep 9, 2024
b3467f2
Merge from master.
ViktorM Sep 9, 2024
80ee4df
Auxiliary losses visual training WIP. Maniskill env switched to using…
ViktorM Sep 10, 2024
39e0ff8
Aux loss is now optional. Config fix.
ViktorM Sep 10, 2024
ee19a1c
Fixed impala network.
ViktorM Sep 11, 2024
fa3469a
Added aux loss to the backbone network. Added debug save images. Trai…
ViktorM Sep 16, 2024
092e559
Depth training config. Dino2 vit support.
ViktorM Sep 19, 2024
ed14b38
WIP: more backbones.
ViktorM Sep 22, 2024
e2c6bd0
WIP: e2e. More vit architectures. Better aux loss support from yaml. t…
ViktorM Sep 30, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -294,6 +294,10 @@ Additional environment supported properties and functions

## Release Notes

1.6.5

* Pytorch 2

1.6.1
* Fixed Central Value RNN bug which occurs if you train ma multi agent environment.
* Added Deepmind Control PPO benchmark.
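The 1.6.5 note is terse; in this PR the PyTorch 2 upgrade mostly surfaces as the device-typed `torch.amp.autocast` API and `@torch.compile()` decorators on the loss functions (see the a2c_continuous.py and central_value.py diffs below). A minimal sketch of the autocast migration, assuming the model and batch already live on the GPU; `forward_with_amp` is an illustrative helper, not a function in the codebase:

```python
import torch

def forward_with_amp(model, batch, mixed_precision=True):
    # PyTorch 2 style: pass the device type explicitly instead of using the
    # deprecated torch.cuda.amp.autocast(enabled=...) context manager.
    with torch.amp.autocast("cuda", enabled=mixed_precision):
        return model(batch)
```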
5 changes: 4 additions & 1 deletion docs/HOW_TO_RL_GAMES.md
@@ -131,7 +131,10 @@ def launch_rlg_hydra(cfg: DictConfig):
# Make a copy of RayVecEnv

class CustomRayVecEnv(IVecEnv):
import ray
try:
import ray
except ImportError:
pass

def __init__(self, config_dict, config_name, num_actors, **kwargs):
### ADDED CHANGE ###
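The guarded import keeps the copy-paste example in the docs usable when ray is not installed; the except branch simply skips the import. A hedged variation (not part of the docs) that fails with a clearer message only when the env is actually constructed:

```python
# Illustrative sketch: module-level guard plus an explicit error at construction time.
try:
    import ray
except ImportError:
    ray = None

class CustomRayVecEnv:  # the real class derives from IVecEnv; omitted here to keep the sketch self-contained
    def __init__(self, config_dict, config_name, num_actors, **kwargs):
        if ray is None:
            raise ImportError("CustomRayVecEnv requires ray (pip install ray)")
        self.config_dict = config_dict
        self.config_name = config_name
        self.num_actors = num_actors
```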
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rl_games"
version = "1.6.1"
version = "1.6.5"
description = ""
readme = "README.md"
authors = [
91 changes: 59 additions & 32 deletions rl_games/algos_torch/a2c_continuous.py
@@ -6,7 +6,12 @@
from rl_games.common import datasets

from torch import optim
import torch
import torch

try:
from apex.optimizers import FusedAdam as AdamImpl
except ImportError:
AdamImpl = optim.Adam


class A2CAgent(a2c_common.ContinuousA2CBase):
@@ -30,18 +35,18 @@ def __init__(self, base_name, params):
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'num_seqs' : self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size',1),
'value_size': self.env_info.get('value_size', 1),
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
self.init_rnn_from_model(self.model)
self.last_lr = float(self.last_lr)
self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound'
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)
self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularization' or 'bound'
self.optimizer = AdamImpl(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)

if self.has_central_value:
cv_config = {
@@ -74,7 +79,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
self.epoch_num += 1
return self.epoch_num

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
@@ -90,6 +95,28 @@ def restore_central_value_function(self, fn):
def get_masked_action_values(self, obs, action_masks):
assert False

@torch.compile() #(mode='max-autotune')
def calc_losses(self, actor_loss_func, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip, value_preds_batch, values, return_batch, mu, entropy, rnn_masks):
a_loss = actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)

if self.has_value_loss:
c_loss = common_losses.critic_loss(self.model,value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
b_loss = self.reg_loss(mu)
elif self.bound_loss_type == 'bound':
b_loss = self.bound_loss(mu)
else:
b_loss = torch.zeros(1, device=self.ppo_device)

losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef

return loss, a_loss, c_loss, entropy, b_loss, sum_mask

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.

@@ -127,30 +154,28 @@ def calc_gradients(self, input_dict):
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast("cuda", enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
entropy = res_dict['entropy']
mu = res_dict['mus']
sigma = res_dict['sigmas']

a_loss = self.actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)

if self.has_value_loss:
c_loss = common_losses.critic_loss(self.model,value_preds_batch, values, curr_e_clip, return_batch, self.clip_value)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
b_loss = self.reg_loss(mu)
elif self.bound_loss_type == 'bound':
b_loss = self.bound_loss(mu)
else:
b_loss = torch.zeros(1, device=self.ppo_device)
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]
loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(
self.actor_loss_func,
old_action_log_probs_batch,
action_log_probs,
advantage,
curr_e_clip,
value_preds_batch,
values,
return_batch,
mu,
entropy,
rnn_masks
)

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
@@ -160,29 +185,33 @@ def calc_gradients(self, input_dict):
self.aux_loss_dict[k] = v.detach()
else:
self.aux_loss_dict[k] = [v.detach()]

if self.multi_gpu:
self.optimizer.zero_grad()
else:
for param in self.model.parameters():
param.grad = None

self.scaler.scale(loss).backward()
#TODO: Refactor this ugliest code of they year
# TODO: Refactor this ugliest code of they year
self.trancate_gradients_and_step()

with torch.no_grad():
reduce_kl = rnn_masks is None
kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
kl_dist = torch_ext.policy_kl(
mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch,
reduce_kl
)
if rnn_masks is not None:
kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask

self.diagnostics.mini_batch(self,
{
'values' : value_preds_batch,
'returns' : return_batch,
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
'values': value_preds_batch,
'returns': return_batch,
'new_neglogp': action_log_probs,
'old_neglogp': old_action_log_probs_batch,
'masks': rnn_masks
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, \
Expand All @@ -208,6 +237,4 @@ def bound_loss(self, mu):
b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
else:
b_loss = 0
return b_loss


return b_loss
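`calc_gradients` now asks the model for `self.model.get_aux_loss()` and records every returned entry in `self.aux_loss_dict`. The diff only shows the consumer side; a minimal sketch of the producer side, assuming `get_aux_loss` returns either `None` or a dict of named scalar tensors (the class and key names below are illustrative, not from this PR):

```python
import torch
from torch import nn

class BackboneWithAuxLoss(nn.Module):
    """Toy encoder/decoder that exposes a reconstruction loss as an auxiliary loss."""

    def __init__(self, obs_dim=64, latent_dim=32):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, obs_dim)
        self._aux_loss = None

    def forward(self, obs):
        latent = self.encoder(obs)
        recon = self.decoder(latent)
        # Cache the loss during the forward pass so the agent can read it afterwards.
        self._aux_loss = {'reconstruction_loss': nn.functional.mse_loss(recon, obs)}
        return latent

    def get_aux_loss(self):
        # Contract used by calc_gradients: None, or a dict whose values are tensors.
        return self._aux_loss
```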
11 changes: 8 additions & 3 deletions rl_games/algos_torch/a2c_discrete.py
@@ -8,9 +8,14 @@

from torch import optim
import torch
from torch import nn

import numpy as np

try:
from apex.optimizers import FusedAdam as AdamImpl
except ImportError:
AdamImpl = optim.Adam


class DiscreteA2CAgent(a2c_common.DiscreteA2CBase):
"""Discrete PPO Agent
@@ -43,7 +48,7 @@ def __init__(self, base_name, params):
self.init_rnn_from_model(self.model)

self.last_lr = float(self.last_lr)
self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)
self.optimizer = AdamImpl(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay)

if self.has_central_value:
cv_config = {
@@ -154,7 +159,7 @@ def calc_gradients(self, input_dict):
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']

with torch.cuda.amp.autocast(enabled=self.mixed_precision):
with torch.amp.autocast("cuda", enabled=self.mixed_precision):
res_dict = self.model(batch_dict)
action_log_probs = res_dict['prev_neglogp']
values = res_dict['values']
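Both agents now prefer NVIDIA Apex's `FusedAdam` and silently fall back to `torch.optim.Adam` when Apex is missing, without touching the call site. A quick illustrative check (not part of the PR) to see which implementation was resolved at import time:

```python
from torch import optim

try:
    from apex.optimizers import FusedAdam as AdamImpl
except ImportError:
    AdamImpl = optim.Adam

# Prints e.g. "torch.optim.adam.Adam" when Apex is not installed.
print(f"Adam implementation in use: {AdamImpl.__module__}.{AdamImpl.__name__}")
```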
39 changes: 27 additions & 12 deletions rl_games/algos_torch/central_value.py
@@ -3,11 +3,17 @@
import torch
from torch import nn
import torch.distributed as dist

from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
from rl_games.common import common_losses
from rl_games.common import datasets
from rl_games.common import schedulers
from rl_games.common import common_losses, datasets, schedulers

try:
from apex.optimizers import FusedAdam as AdamImpl
from apex.contrib.clip_grad import clip_grad_norm_
except ImportError:
AdamImpl = torch.optim.Adam
from torch.nn.utils import clip_grad_norm_


class CentralValueTrain(nn.Module):
@@ -62,7 +68,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng

self.writter = writter
self.weight_decay = config.get('weight_decay', 0.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), float(self.lr), eps=1e-08, weight_decay=self.weight_decay)
self.optimizer = AdamImpl(self.model.parameters(), float(self.lr), eps=1e-08, weight_decay=self.weight_decay)

self.frame = 0
self.epoch_num = 0
self.running_mean_std = None
@@ -225,7 +232,7 @@ def train_net(self):
for idx in range(len(self.dataset)):
loss += self.train_critic(self.dataset[idx])
if self.normalize_input:
self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch
self.model.running_mean_std.eval() # don't need to update statistics more than one miniepoch
avg_loss = loss / (self.mini_epoch * self.num_minibatches)

self.epoch_num += 1
Expand All @@ -234,8 +241,16 @@ def train_net(self):
self.frame += self.batch_size
if self.writter != None:
self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)

return avg_loss

@torch.compile() #(mode='max-autotune')
def calc_loss(self, value_preds_batch, values, returns_batch, rnn_masks_batch):
loss = common_losses.critic_loss(self.model, value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
#print(loss.min(), loss.max(), loss.size(), rnn_masks_batch)
losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
return losses[0]

def calc_gradients(self, batch):
obs_batch = self._preproc_obs(batch['obs'])
@@ -254,11 +269,9 @@ def calc_gradients(self, batch):

res_dict = self.model(batch_dict)
values = res_dict['values']
loss = common_losses.critic_loss(self.model, value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
#print(loss.min(), loss.max(), loss.size(), rnn_masks_batch)
losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
loss = losses[0]
#6print('aaa', loss.min(), loss.max(), loss.size())

loss = self.calc_loss(value_preds_batch, values, returns_batch, rnn_masks_batch)

if self.multi_gpu:
self.optimizer.zero_grad()
else:
Expand All @@ -273,7 +286,9 @@ def calc_gradients(self, batch):
if param.grad is not None:
all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)

dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)

offset = 0
for param in self.model.parameters():
if param.grad is not None:
Expand All @@ -283,7 +298,7 @@ def calc_gradients(self, batch):
offset += param.numel()

if self.truncate_grads:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
clip_grad_norm_(self.model.parameters(), self.grad_norm)

self.optimizer.step()

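`calc_loss` here mirrors `calc_losses` in the continuous agent: the critic loss is factored into its own method and wrapped in `@torch.compile()`, with `mode='max-autotune'` left commented out for now. A standalone sketch of the same pattern on a generic clipped value loss (illustrative function, not from the codebase; the first call triggers compilation, so it needs a working inductor backend):

```python
import torch

@torch.compile()  # mode='max-autotune' can be switched on once it proves stable
def clipped_value_loss(values, old_values, returns, clip):
    # Standard PPO-style clipped value loss, compiled into a single graph.
    values_clipped = old_values + (values - old_values).clamp(-clip, clip)
    per_sample = torch.max((values - returns) ** 2, (values_clipped - returns) ** 2)
    return per_sample.mean()
```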
1 change: 1 addition & 0 deletions rl_games/algos_torch/d2rl.py
@@ -1,5 +1,6 @@
import torch


class D2RLNet(torch.nn.Module):
def __init__(self, input_size,
units,
2 changes: 1 addition & 1 deletion rl_games/algos_torch/flatten.py
@@ -2,11 +2,11 @@
import collections
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import torch
from torch import nn



@dataclass
class Schema:
"""
7 changes: 3 additions & 4 deletions rl_games/algos_torch/model_builder.py
@@ -1,7 +1,6 @@
from rl_games.common import object_factory
import rl_games.algos_torch
from rl_games.algos_torch import network_builder
from rl_games.algos_torch import models
from rl_games.algos_torch import network_builder, models

NETWORK_REGISTRY = {}
MODEL_REGISTRY = {}
@@ -19,8 +18,8 @@ def __init__(self):
self.network_factory.set_builders(NETWORK_REGISTRY)
self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder())
self.network_factory.register_builder('resnet_actor_critic',
lambda **kwargs: network_builder.A2CResnetBuilder())
self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder())
lambda **kwargs: network_builder.A2CResnetBuilder())

self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder())

def load(self, params):
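The registration pattern shown here, a network name mapped to a lambda returning its builder, is what the new vision backbone networks hook into. A hedged sketch of registering a custom builder the same way, assuming the module-level `register_network` helper that fills `NETWORK_REGISTRY`; the `VisionBackboneBuilder` name is hypothetical, not a class added in this PR:

```python
from rl_games.algos_torch import model_builder, network_builder

class VisionBackboneBuilder(network_builder.A2CBuilder):
    """Hypothetical builder; reuses A2CBuilder's load/build interface for the sketch."""
    pass

# After registration, a training config can select the builder via its network name.
model_builder.register_network('vision_actor_critic', VisionBackboneBuilder)
```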