Skip to content

Commit

Permalink
Added convnext and vit backbones support. Added preprocessing.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 20, 2024
1 parent 006fe2c commit 61d1bc4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
52 changes: 39 additions & 13 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,20 @@ def build(self, name, **kwargs):
return net


from torchvision import models
from torchvision import models, transforms

def preprocess_image(image):
    """Normalize ``image`` channel-wise with ImageNet mean and std.

    Args:
        image: Image tensor to normalize.
            NOTE(review): the three-element statistics imply a 3-channel
            input — confirm against callers.

    Returns:
        The result of applying ``transforms.Normalize`` to ``image``.
    """
    # Per-channel statistics of the ImageNet dataset.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    return transforms.Normalize(mean=imagenet_mean, std=imagenet_std)(image)


class VisionBackboneBuilder(NetworkBuilder):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -1103,6 +1116,7 @@ def __init__(self, params, **kwargs):

self.cnn, self.cnn_output_size = self._build_backbone(input_shape, params['backbone'])

self.resize_transform = transforms.Resize((224, 224))
mlp_input_size = self.cnn_output_size + self.proprio_size
if len(self.units) == 0:
out_size = self.cnn_output_size
Expand Down Expand Up @@ -1176,16 +1190,20 @@ def forward(self, obs_dict):
if self.permute_input:
obs = obs.permute((0, 3, 1, 2))

if self.preprocess_image:
obs = preprocess_image(obs)

# Assuming your input image is a tensor or PIL image, resize it to 224x224
#obs = self.resize_transform(obs)

dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)
states = obs_dict.get('rnn_states', None)

out = obs
out = self.cnn(out)
#print(out.shape)
out = out.flatten(1)
#print(out.shape)
#print('AAAAAAAAAAAAAAAAAaaa')

out = self.flatten_act(out)

if self.proprio_size > 0:
Expand Down Expand Up @@ -1272,12 +1290,15 @@ def load(self, params):
def _build_backbone(self, input_shape, backbone_params):
backbone_type = backbone_params['type']
pretrained = backbone_params.get('pretrained', False)
self.preprocess_image = backbone_params.get('preprocess_image', False)

if backbone_type == 'resnet18':
backbone = models.resnet18(pretrained=pretrained, zero_init_residual=True) # norm_layer=nn.LayerNorm
# Modify the first convolution layer to match input shape if needed

# TODO: add low-res parameter
backbone.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False)
backbone.maxpool = nn.Identity()
#backbone.maxpool = nn.Identity()
# if input_shape[0] != 3:
# model.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False)
# Remove the fully connected layer
Expand All @@ -1289,16 +1310,21 @@ def _build_backbone(self, input_shape, backbone_params):
backbone_output_size = backbone.classifier[2].in_features
backbone.classifier = nn.Identity()

# Do we need it?
# backbone = nn.Sequential(*list(backbone.children())[:-1])
elif backbone_type == 'vit_tiny_patch16_224':
backbone = models.vit_small_patch16_224(pretrained=pretrained)
backbone_output_size = backbone.heads.head.in_features
backbone.heads.head = nn.Identity()
# Modify the first convolutional layer to work with smaller resolutions
backbone.features[0][0] = nn.Conv2d(
in_channels=input_shape[0],
out_channels=backbone.features[0][0].out_channels,
kernel_size=3, # Reduce kernel size to 3x3
stride=1, # Reduce stride to 1 to preserve spatial resolution
padding=1, # Add padding to preserve dimensions after convolution
bias=True # False
)

# ViT outputs a single token, so no need to remove layers
# Is it true?
elif backbone_type == 'vit_b_16':
backbone = models.vision_transformer.vit_b_16(pretrained=pretrained)

backbone_output_size = backbone.heads.head.in_features
backbone.heads.head = nn.Identity()
else:
raise ValueError(f'Unknown backbone type: {backbone_type}')

Expand Down
22 changes: 7 additions & 15 deletions rl_games/configs/atari/ppo_pong_envpool_backbone.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,13 @@ params:
value_shape: 1
space:
discrete:

# cnn:
# permute_input: False
# conv_depths: [16, 32, 32]
# activation: relu
# initializer:
# name: default
# regularizer:
# name: 'None'

backbone:
type: resnet18 #convnext_tiny #vit_tiny_patch16_224
type: resnet18 #convnext_tiny #vit_b_16 #resnet18
pretrained: True
permute_input: False
freeze: True
freeze: False
preprocess_image: False

args:
zero_init_residual: True
Expand All @@ -49,7 +41,7 @@ params:
# units: 256
# layers: 1
config:
name: pong_resnet18_pretrained_2_mini_epoch_1e-4_linear_lr_norm_value_frozen
name: pong_resnet18_maxpool
env_name: envpool
reward_shaper:
min_val: -1
Expand All @@ -71,7 +63,7 @@ params:
e_clip: 0.2
clip_value: True
save_best_after: 25
save_frequency: 50
save_frequency: 500
num_actors: 32
horizon_length: 64
minibatch_size: 512
Expand All @@ -81,14 +73,14 @@ params:
kl_threshold: 0.01
use_diagnostics: True
seq_length: 8
max_epochs: 1000
max_epochs: 2000
#weight_decay: 0.001

env_config:
env_name: Pong-v5
has_lives: False
use_dict_obs_space: False #True

stack_num: 3
player:
render: True
games_num: 10
Expand Down

0 comments on commit 61d1bc4

Please sign in to comment.