
Commit

Resnet18 works!
ViktorM committed Aug 19, 2024
1 parent eeae30e commit 006fe2c
Showing 2 changed files with 31 additions and 26 deletions.
25 changes: 15 additions & 10 deletions rl_games/algos_torch/network_builder.py
@@ -1072,7 +1072,6 @@ def build(self, name, **kwargs):
 
 
 from torchvision import models
-from timm import create_model # timm is required for ConvNeXt and ViT
 
 class VisionBackboneBuilder(NetworkBuilder):
     def __init__(self, **kwargs):
@@ -1172,6 +1171,8 @@ def forward(self, obs_dict):
             proprio = obs_dict['proprio']
         else:
             obs = obs_dict['obs']
+
+        # print('obs.shape: ', obs.shape)
         if self.permute_input:
             obs = obs.permute((0, 3, 1, 2))
@@ -1181,7 +1182,10 @@
 
         out = obs
         out = self.cnn(out)
+        #print(out.shape)
         out = out.flatten(1)
+        #print(out.shape)
+        #print('AAAAAAAAAAAAAAAAAaaa')
         out = self.flatten_act(out)
 
         if self.proprio_size > 0:
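For context, a minimal sketch of the forward path this hunk touches, assuming NHWC image observations and an optional 'proprio' vector as in the code above (the standalone helper is illustrative, not the repo's API):

import torch

def forward_features(cnn, flatten_act, obs, proprio=None, permute_input=True):
    if permute_input:
        obs = obs.permute(0, 3, 1, 2)   # NHWC -> NCHW, as torchvision models expect
    out = cnn(obs)                      # e.g. (N, 512, 1, 1) from the resnet18 trunk
    out = out.flatten(1)                # -> (N, 512)
    out = flatten_act(out)
    if proprio is not None:
        out = torch.cat([out, proprio], dim=1)  # append proprioceptive state
    return out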
@@ -1281,19 +1285,20 @@ def _build_backbone(self, input_shape, backbone_params):
             print('backbone_output_size: ', backbone_output_size)
             backbone = nn.Sequential(*list(backbone.children())[:-1])
         elif backbone_type == 'convnext_tiny':
-            backbone = create_model('convnext_tiny', pretrained=pretrained)
-            # Modify the first convolution layer to match input shape if needed
-            #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
-            # Remove the fully connected layer
-            backbone_output_size = backbone.head.fc.in_features
+            backbone = models.convnext_tiny(pretrained=pretrained)
+            backbone_output_size = backbone.classifier[2].in_features
+            backbone.classifier = nn.Identity()
 
-            backbone = nn.Sequential(*list(backbone.children())[:-1])
+            # Do we need it?
+            # backbone = nn.Sequential(*list(backbone.children())[:-1])
         elif backbone_type == 'vit_tiny_patch16_224':
-            backbone = create_model('vit_tiny_patch16_224', pretrained=pretrained)
-            # # ViT outputs a single token, so no need to remove layers
-            # backbone = models.vit_small_patch16_224(pretrained=pretrained)
+            backbone = models.vit_small_patch16_224(pretrained=pretrained)
+            backbone_output_size = backbone.heads.head.in_features
+            backbone.heads.head = nn.Identity()
 
+            # ViT outputs a single token, so no need to remove layers
+            # Is it true?
 
         else:
             raise ValueError(f'Unknown backbone type: {backbone_type}')

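The new branches drop the timm dependency in favor of torchvision. One caveat: torchvision has no models.vit_small_patch16_224 (that is a timm-style name; torchvision's ViTs are vit_b_16 and friends), so the ViT branch as committed would raise AttributeError. A hedged sketch of the head-stripping pattern using names that do exist in torchvision >= 0.13 (the build_backbone helper is illustrative, not the repo's exact code):

import torch
import torch.nn as nn
from torchvision import models

def build_backbone(backbone_type, pretrained=False):
    if backbone_type == 'resnet18':
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        output_size = backbone.fc.in_features              # 512
        backbone.fc = nn.Identity()
    elif backbone_type == 'convnext_tiny':
        weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
        backbone = models.convnext_tiny(weights=weights)
        output_size = backbone.classifier[2].in_features   # 768
        backbone.classifier = nn.Identity()                # output stays (N, 768, 1, 1); flatten later
    elif backbone_type == 'vit_b_16':
        weights = models.ViT_B_16_Weights.DEFAULT if pretrained else None
        backbone = models.vit_b_16(weights=weights)
        output_size = backbone.heads.head.in_features      # 768
        backbone.heads.head = nn.Identity()                # ViT already returns a single class token
    else:
        raise ValueError(f'Unknown backbone type: {backbone_type}')
    return backbone, output_size

# Sanity check that also answers the "Do we need it?" comment: replacing the
# classifier with nn.Identity is enough, since flatten(1) removes the 1x1 spatial dims.
backbone, size = build_backbone('convnext_tiny')
feats = backbone(torch.zeros(1, 3, 224, 224)).flatten(1)
assert feats.shape[1] == size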
32 changes: 16 additions & 16 deletions rl_games/configs/atari/ppo_pong_envpool_backbone.yaml
@@ -27,20 +27,20 @@ params:
     #  name: 'None'
 
     backbone:
-      type: resnet18 #vit_tiny_patch16_224 #convnext_tiny #resnet18
-      pretrained: False
-      permute_input: True
-      freeze: False
+      type: resnet18 #convnext_tiny #vit_tiny_patch16_224
+      pretrained: True
+      permute_input: False
+      freeze: True
 
       args:
-        zero_init_residual: False
+        zero_init_residual: True
         norm_layer: None
 
     mlp:
       units: [512]
       activation: relu
       regularizer:
-        name: 'None'
+        name: None
       initializer:
         name: default
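
The config now loads ImageNet weights and freezes them (pretrained: True, freeze: True), with permute_input: False, presumably because envpool already delivers channels-first frames. A hedged sketch of what freezing a pretrained backbone typically amounts to in PyTorch (assumption: rl_games freezes parameters rather than dropping the module; the helper name is illustrative):

def freeze_backbone(backbone):
    backbone.eval()                  # also fixes BatchNorm running statistics
    for p in backbone.parameters():
        p.requires_grad = False      # excluded from optimizer updates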

@@ -49,35 +49,35 @@ params:
   #    units: 256
   #    layers: 1
   config:
-    name: pong_resnet18_nopretrained_novaluenorm_weightdecay_nomaxpool
+    name: pong_resnet18_pretrained_2_mini_epoch_1e-4_linear_lr_norm_value_frozen
     env_name: envpool
     reward_shaper:
       min_val: -1
       max_val: 1
 
-    mixed_precision: False
+    mixed_precision: True
     normalize_input: False
-    normalize_value: False
+    normalize_value: True
     normalize_advantage: True
-    gamma: 0.995
+    gamma: 0.99
     tau: 0.95
     learning_rate: 1e-4
 
     score_to_win: 100000
-    grad_norm: 1.5
+    grad_norm: 1.0
     entropy_coef: 0.01
     truncate_grads: True
 
    e_clip: 0.2
     clip_value: True
     save_best_after: 25
     save_frequency: 50
-    num_actors: 64
-    horizon_length: 128
-    minibatch_size: 2048
-    mini_epochs: 4
+    num_actors: 32
+    horizon_length: 64
+    minibatch_size: 512
+    mini_epochs: 2
     critic_coef: 1
-    lr_schedule: None
+    lr_schedule: linear
     kl_threshold: 0.01
     use_diagnostics: True
     seq_length: 8
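The rollout arithmetic still divides evenly after the cuts: 32 actors x 64 steps gives a 2048-sample batch, split into four 512-sample minibatches for each of the 2 mini-epochs. A quick sanity check with the values from the config above:

num_actors, horizon_length = 32, 64
minibatch_size, mini_epochs = 512, 2
batch_size = num_actors * horizon_length                   # 2048 samples per rollout
assert batch_size % minibatch_size == 0                    # rl_games expects an even split
grad_steps = (batch_size // minibatch_size) * mini_epochs  # 4 * 2 = 8 updates per rollout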
