How to use custom network? #301

Open
ErcBunny opened this issue Aug 9, 2024 · 8 comments

Comments

@ErcBunny
Contributor

ErcBunny commented Aug 9, 2024

I would like to use the following network for my project, but I am not sure how exactly to do it.

                                                     actor    
    ┌─────┐      ┌─────────┐   ┌────┐  ┌─────────┐   ┌───┐    
x──►│ CNN ├─────►│torch.cat│──►│LSTM├─►│torch.cat├┬─►│MLP├──►a
    └─────┘      └─────────┘   └────┘  └─────────┘│  └───┘    
                      ▲                   ▲  ▲    │  ┌───┐    
                      │                   │  │    └─►│MLP├──►v 
                      y───────────────────┘  z       └───┘    
                                                     value    

In the diagram, x, y, and z come from the observation dictionary; a is the action and v is the value.
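A minimal plain-PyTorch sketch of the data flow in the diagram, assuming the observation arrives as a dict with keys "x" (a single-channel image of shape B×1×H×W) plus "y" and "z" (flat vectors). The class name `DictActorCritic`, all layer sizes, and the single-time-step LSTM call are hypothetical; this is not the rl_games `NetworkBuilder` API, only an illustration of which tensor feeds which block. For a continuous policy the actor output would typically be the action mean.

```python
import torch
import torch.nn as nn


class DictActorCritic(nn.Module):
    """Hypothetical sketch: CNN(x) ++ y -> LSTM -> ++ y, z -> actor / value MLPs."""

    def __init__(self, y_dim=8, z_dim=4, act_dim=3, cnn_out=32, lstm_hidden=64):
        super().__init__()
        # Small CNN encoder; adaptive pooling makes it independent of H and W.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=4), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),  # -> (B, 32)
            nn.Linear(32, cnn_out), nn.ReLU(),
        )
        self.lstm = nn.LSTM(cnn_out + y_dim, lstm_hidden, batch_first=True)
        head_in = lstm_hidden + y_dim + z_dim
        self.actor = nn.Sequential(nn.Linear(head_in, 64), nn.ELU(), nn.Linear(64, act_dim))
        self.value = nn.Sequential(nn.Linear(head_in, 64), nn.ELU(), nn.Linear(64, 1))

    def forward(self, obs, hidden=None):
        x, y, z = obs["x"], obs["y"], obs["z"]
        feat = torch.cat([self.cnn(x), y], dim=-1)           # first torch.cat
        out, hidden = self.lstm(feat.unsqueeze(1), hidden)   # one time step
        h = torch.cat([out.squeeze(1), y, z], dim=-1)        # second torch.cat
        return self.actor(h), self.value(h), hidden


net = DictActorCritic()
obs = {"x": torch.rand(2, 1, 64, 48), "y": torch.rand(2, 8), "z": torch.rand(2, 4)}
a, v, (h, c) = net(obs)
print(a.shape, v.shape)  # torch.Size([2, 3]) torch.Size([2, 1])
```

The LSTM hidden state is returned so a caller can carry it across environment steps; how rl_games threads RNN states through its models is not shown here.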

Thank you very much for considering my question and I look forward to the guidance.

@Denys88
Owner

Denys88 commented Aug 11, 2024

Here is a good example of how I tested fairly complex networks with IsaacGym: https://github.com/Denys88/IsaacGymEnvs/blob/main/isaacgymenvs/learning/networks/ig_networks.py
Let me know if that is enough for you.

@ViktorM
Collaborator

ViktorM commented Aug 11, 2024

Not exactly your example, but here is a very similar ResNet network builder (A2CResnetBuilder) with support for RNN (LSTM) layers.

@ErcBunny
Contributor Author

ErcBunny commented Aug 11, 2024

Thank you @Denys88 and @ViktorM for providing the examples and the pointer to the A2CResnetBuilder.

While waiting for an answer, I also looked at the network builder code and found A2CBuilder and A2CResnetBuilder, both of which provide blocks for building a CNN/ResNet + LSTM + MLP network.

Both seem to accept only obs_dict['obs'] as the single input to the forward function, but my project has not only the image tensor obs_dict['x'] but also the state tensors obs_dict['y'] and obs_dict['z'], which are consumed by different blocks of the net.

So I am planning to create a class derived from NetworkBuilder that mimics either A2CBuilder or A2CResnetBuilder (by the way, which one is better suited to my single-channel, normalized depth image of size (256, 192)?) and to modify the forward function (and perhaps other necessary initialization parts) to adapt it to my obs_dict. I guess I'll also need a new model derived from ModelA2CContinuousLogStd to make it work. Is this approach feasible, and could it cause problems?
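On the (256, 192) input size: whichever builder is chosen, it helps to check how large the flattened CNN output becomes. Below is a pure-Python sketch of the standard Conv2d output-size formula, floor((n + 2p - k) / s) + 1, applied to a hypothetical three-layer stack (the classic Nature-DQN kernel/stride choice, used only as an example and not claimed to match either builder's defaults).

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial output size of a Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical conv stack: (out_channels, kernel, stride), Nature-DQN style.
layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]

h, w = 256, 192  # depth image height and width from the question
for channels, k, s in layers:
    h, w = conv_out(h, k, s), conv_out(w, k, s)
    print(f"{channels} x {h} x {w}")

flat = channels * h * w
print("flattened features:", flat)
```

With this stack the feature maps go 32×63×47, 64×30×22, 64×28×20, so the flattened vector has 35,840 entries and the first dense or LSTM layer after the CNN would be very wide; extra strided layers or pooling would shrink it.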

Please correct me if I've misunderstood anything. Looking forward to hearing your thoughts on this approach and any recommendations you might have!

@ankurhanda
Collaborator

ankurhanda commented Aug 11, 2024

@ViktorM @Denys88 the example above assumes that you are using a frozen network. You can't optimise the weights of this network because rl_games runs the running mean/std normalisation inside a torch.inference_mode() context, which breaks the compute graph for the vision network.

So, this is only suitable for pre-trained networks and not end to end visual RL.

@ErcBunny
Contributor Author

ErcBunny commented Aug 11, 2024

Thanks for your comment, @ankurhanda. I have a question about standardization breaking the compute graph for the vision net.

I decided to first implement a simpler version of my network illustrated like this:

                                                     actor    
    ┌─────┐      ┌─────────┐   ┌────┐  ┌─────────┐   ┌───┐    
x──►│ CNN ├─────►│torch.cat│──►│LSTM├─►│torch.cat├┬─►│MLP├──►a
    └─────┘      └─────────┘   └────┘  └─────────┘│  └───┘    
                      ▲                           │  ┌───┐    
                      │                           └─►│MLP├──►v 
                      y                              └───┘    
                                                     value    

where x is retrieved from input_dict["obs"]["image"] and y from input_dict["obs"]["state"].

My question is: if I use running statistics only to standardize y, and manually normalize x to [0, 1] inside my env step, is end-to-end learning with the CNN possible?
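The setup described in the question can be sketched as follows, assuming the observation dict holds an "image" already scaled to [0, 1] by the env and a flat "state" vector. `RunningMeanStd` and `preprocess` are hypothetical helpers using the standard parallel mean/variance update; they are not rl_games' own classes. Statistics are updated and applied only to the state, under `torch.no_grad()`, and the image is passed to the CNN untouched.

```python
import torch


class RunningMeanStd:
    """Hypothetical running mean/variance tracker (parallel update rule)."""

    def __init__(self, shape, epsilon=1e-4):
        self.mean = torch.zeros(shape)
        self.var = torch.ones(shape)
        self.count = epsilon  # small prior count avoids division by zero

    def update(self, batch):
        b_mean = batch.mean(dim=0)
        b_var = batch.var(dim=0, unbiased=False)
        b_count = batch.shape[0]
        delta = b_mean - self.mean
        total = self.count + b_count
        self.mean = self.mean + delta * b_count / total
        m2 = self.var * self.count + b_var * b_count + delta.pow(2) * self.count * b_count / total
        self.var = m2 / total
        self.count = total

    def normalize(self, batch, eps=1e-5):
        return (batch - self.mean) / torch.sqrt(self.var + eps)


def preprocess(obs, rms):
    # Only the low-dimensional state is standardized with running statistics;
    # the image is assumed to be in [0, 1] already and is returned as-is.
    with torch.no_grad():
        rms.update(obs["state"])
        state = rms.normalize(obs["state"])
    return {"image": obs["image"], "state": state}


torch.manual_seed(0)
rms = RunningMeanStd(shape=(6,))
obs = {"image": torch.rand(1000, 1, 8, 8), "state": torch.randn(1000, 6) * 3.0 + 5.0}
out = preprocess(obs, rms)
```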

@ankurhanda
Copy link
Collaborator

ankurhanda commented Aug 11, 2024

As long as you don't do anything to the CNN, you should be fine. Normalizing x should be OK.

My main concern is whether you want to optimise the CNN weights end to end. The current settings don't allow that, because the compute graph is broken during normalisation inside the rl_games code.

def norm_obs(self, observation):

@ErcBunny
Copy link
Contributor Author

ErcBunny commented Aug 11, 2024

I am trying to do e2e learning so that the CNN weights are optimized as well. Why does normalizing the input to a network under no grad break the compute graph? Could you share more details?

I assume that if the concatenated tensor of (x, y) were normalized under no grad, the CNN parameters would not be updated. But in my case normalization happens at the inputs, so I guess it is probably fine? Please correct me if I am wrong...
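A small experiment illustrates the distinction being asked about. It assumes the normalization runs under `torch.no_grad()` (an assumption for illustration, not a quote of rl_games' code) and uses a tiny hypothetical CNN. Normalizing the raw input without gradient tracking still lets gradients reach the CNN weights, because the weight gradients need only the input values, not gradients with respect to the input. Normalizing the CNN's output under `no_grad` detaches it, so the CNN receives no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in CNN and head (hypothetical, for illustration only).
cnn = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU(), nn.Flatten())
head = nn.Linear(4 * 6 * 6, 1)  # 8x8 input -> 6x6 maps after a 3x3 conv

x = torch.rand(2, 1, 8, 8)  # raw "image" batch

# Case 1: normalize the INPUT without gradient tracking, then run the CNN.
with torch.no_grad():
    x_norm = (x - x.mean()) / (x.std() + 1e-5)
head(cnn(x_norm)).sum().backward()
input_norm_grad_ok = cnn[0].weight.grad is not None

cnn.zero_grad(set_to_none=True)
head.zero_grad(set_to_none=True)

# Case 2: normalize the CNN OUTPUT without gradient tracking.
feat = cnn(x)
with torch.no_grad():
    feat_norm = (feat - feat.mean()) / (feat.std() + 1e-5)
head(feat_norm).sum().backward()
output_norm_grad_ok = cnn[0].weight.grad is not None

print(input_norm_grad_ok)   # True
print(output_norm_grad_ok)  # False
```

If `torch.inference_mode()` were used instead of `no_grad`, the normalized input would be an inference tensor, and PyTorch may refuse to save it for backward when it is fed into a layer being trained; cloning it outside inference mode is the usual workaround.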

@ViktorM
Copy link
Collaborator

ViktorM commented Aug 13, 2024

@ankurhanda I don't think we use torch.inference_mode() anywhere in the code; can you point to the exact place? The example above, https://github.com/Denys88/rl_games/blob/master/rl_games/algos_torch/network_builder.py#L623, is for end-to-end training, and we have configs for Atari training from scratch: https://github.com/Denys88/rl_games/blob/master/rl_games/configs/atari/ppo_breakout_torch_impala.yaml

It can easily be modified to load pre-trained weights and either freeze them or not, but the default variant is intended for e2e training.
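A generic PyTorch sketch of the "load pre-trained weights and freeze them, or not" option. The encoder/head modules and the `FREEZE_ENCODER` flag are hypothetical, the checkpoint path in the comment is a placeholder, and none of this is an rl_games config option. Freezing simply disables `requires_grad` on the encoder parameters and keeps them out of the optimizer; setting the flag to `False` gives end-to-end training.

```python
import torch
import torch.nn as nn

# Hypothetical encoder + head. With a real checkpoint one would first call
# encoder.load_state_dict(torch.load("encoder.pt"))  # placeholder path
encoder = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, stride=2), nn.ReLU(), nn.Flatten())
head = nn.Linear(8 * 7 * 7, 2)  # 16x16 input -> 7x7 maps after the strided conv

FREEZE_ENCODER = True  # set to False for end-to-end training

for p in encoder.parameters():
    p.requires_grad_(not FREEZE_ENCODER)

# Only pass trainable parameters to the optimizer.
params = [p for p in list(encoder.parameters()) + list(head.parameters()) if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=3e-4)

x = torch.rand(4, 1, 16, 16)
loss = head(encoder(x)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

With the encoder frozen, only the head's weight and bias are optimized and the encoder's `.grad` stays `None`.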


4 participants