From 006fe2cae6935e080c045d8b1c4763e5ab6b1479 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 19 Aug 2024 10:44:29 -0700 Subject: [PATCH] Resnet18 works! --- rl_games/algos_torch/network_builder.py | 25 +++++++++------ .../atari/ppo_pong_envpool_backbone.yaml | 32 +++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 6fb00ff0..3323dc84 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -1072,7 +1072,6 @@ def build(self, name, **kwargs): from torchvision import models -from timm import create_model # timm is required for ConvNeXt and ViT class VisionBackboneBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -1172,6 +1171,8 @@ def forward(self, obs_dict): proprio = obs_dict['proprio'] else: obs = obs_dict['obs'] + + # print('obs.shape: ', obs.shape) if self.permute_input: obs = obs.permute((0, 3, 1, 2)) @@ -1181,7 +1182,10 @@ def forward(self, obs_dict): out = obs out = self.cnn(out) + #print(out.shape) out = out.flatten(1) + #print(out.shape) + #print('AAAAAAAAAAAAAAAAAaaa') out = self.flatten_act(out) if self.proprio_size > 0: @@ -1281,19 +1285,20 @@ def _build_backbone(self, input_shape, backbone_params): print('backbone_output_size: ', backbone_output_size) backbone = nn.Sequential(*list(backbone.children())[:-1]) elif backbone_type == 'convnext_tiny': - backbone = create_model('convnext_tiny', pretrained=pretrained) - # Modify the first convolution layer to match input shape if needed - #backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) - # Remove the fully connected layer - backbone_output_size = backbone.head.fc.in_features + backbone = models.convnext_tiny(pretrained=pretrained) + backbone_output_size = backbone.classifier[2].in_features + backbone.classifier = nn.Identity() - backbone = nn.Sequential(*list(backbone.children())[:-1]) + # Do we need it? + # backbone = nn.Sequential(*list(backbone.children())[:-1]) elif backbone_type == 'vit_tiny_patch16_224': - backbone = create_model('vit_tiny_patch16_224', pretrained=pretrained) - # # ViT outputs a single token, so no need to remove layers - # backbone = models.vit_small_patch16_224(pretrained=pretrained) + backbone = models.vit_small_patch16_224(pretrained=pretrained) backbone_output_size = backbone.heads.head.in_features backbone.heads.head = nn.Identity() + + # ViT outputs a single token, so no need to remove layers + # Is it true? + else: raise ValueError(f'Unknown backbone type: {backbone_type}') diff --git a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml index 6ae0f5c1..ff63181f 100644 --- a/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml +++ b/rl_games/configs/atari/ppo_pong_envpool_backbone.yaml @@ -27,20 +27,20 @@ params: # name: 'None' backbone: - type: resnet18 #vit_tiny_patch16_224 #convnext_tiny #resnet18 - pretrained: False - permute_input: True - freeze: False + type: resnet18 #convnext_tiny #vit_tiny_patch16_224 + pretrained: True + permute_input: False + freeze: True args: - zero_init_residual: False + zero_init_residual: True norm_layer: None mlp: units: [512] activation: relu regularizer: - name: 'None' + name: None initializer: name: default @@ -49,22 +49,22 @@ params: # units: 256 # layers: 1 config: - name: pong_resnet18_nopretrained_novaluenorm_weightdecay_nomaxpool + name: pong_resnet18_pretrained_2_mini_epoch_1e-4_linear_lr_norm_value_frozen env_name: envpool reward_shaper: min_val: -1 max_val: 1 - mixed_precision: False + mixed_precision: True normalize_input: False - normalize_value: False + normalize_value: True normalize_advantage: True - gamma: 0.995 + gamma: 0.99 tau: 0.95 learning_rate: 1e-4 score_to_win: 100000 - grad_norm: 1.5 + grad_norm: 1.0 entropy_coef: 0.01 truncate_grads: True @@ -72,12 +72,12 @@ params: clip_value: True save_best_after: 25 save_frequency: 50 - num_actors: 64 - horizon_length: 128 - minibatch_size: 2048 - mini_epochs: 4 + num_actors: 32 + horizon_length: 64 + minibatch_size: 512 + mini_epochs: 2 critic_coef: 1 - lr_schedule: None + lr_schedule: linear kl_threshold: 0.01 use_diagnostics: True seq_length: 8