Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Aug 16, 2024
1 parent 80f43f7 commit 7ea73fa
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 22 deletions.
10 changes: 0 additions & 10 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,6 @@ def get_action_values(self, obs):
'rnn_states' : self.rnn_states
}

# if 'proprio' in obs:
# input_dict['proprio'] = obs['proprio']

with torch.no_grad():
res_dict = self.model(input_dict)
if self.has_central_value:
Expand Down Expand Up @@ -452,8 +449,6 @@ def get_values(self, obs):
'obs' : processed_obs,
'rnn_states' : self.rnn_states
}
# if 'proprio' in obs:
# input_dict['proprio'] = obs['proprio']

result = self.model(input_dict)
value = result['values']
Expand Down Expand Up @@ -835,8 +830,6 @@ def play_steps_rnn(self):

self.rnn_states = res_dict['rnn_states']
self.experience_buffer.update_data('obses', n, self.obs['obs'])
# if 'proprio' in self.obs:
# self.experience_buffer.update_data('proprio', n, self.obs['proprio'])
self.experience_buffer.update_data('dones', n, self.dones.byte())

for k in update_list:
Expand Down Expand Up @@ -1030,9 +1023,6 @@ def prepare_dataset(self, batch_dict):
dataset_dict['rnn_states'] = rnn_states
dataset_dict['rnn_masks'] = rnn_masks

# if 'proprio' in batch_dict:
# dataset_dict['proprio'] = batch_dict['proprio']

if self.use_action_masks:
dataset_dict['action_masks'] = batch_dict['action_masks']

Expand Down
6 changes: 1 addition & 5 deletions rl_games/common/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,7 @@ def _init_from_env_info(self, env_info):
state_base_shape = self.state_base_shape

self.tensor_dict['obses'] = self._create_tensor_from_space(env_info['observation_space'], obs_base_shape)
# print("obs base shape", obs_base_shape)
# print('obses shape:', self.tensor_dict['obses'].shape)
# print('proprioception_space shape:', env_info.get('proprioception_space'))
# if env_info.get('proprieception_space') is not None:
# self.tensor_dict['proprio'] = self._create_tensor_from_space(env_info['proprioception_space'], self.obs_base_shape)

if self.has_central_value:
self.tensor_dict['states'] = self._create_tensor_from_space(env_info['state_space'], state_base_shape)

Expand Down
13 changes: 6 additions & 7 deletions rl_games/networks/vision_networks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
import torch_ext
from rl_games.algos_torch.network_builder import NetworkBuilder
from rl_games.algos_torch.network_builder import NetworkBuilder, ImpalaSequential


class A2CVisionBuilder(NetworkBuilder):
Expand All @@ -20,7 +21,8 @@ def __init__(self, params, **kwargs):
if type(input_shape) is dict:
input_shape = input_shape['camera']
proprio_shape = input_shape['proprio']
self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)

self.num_seqs = kwargs.pop('num_seqs', 1)
self.value_size = kwargs.pop('value_size', 1)

NetworkBuilder.BaseNetwork.__init__(self)
Expand Down Expand Up @@ -208,10 +210,6 @@ def build(self, name, **kwargs):
return net


import torch
import torch.nn as nn
from torchvision import models

class A2CVisionBackboneBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)
Expand All @@ -227,7 +225,8 @@ def __init__(self, params, **kwargs):
if isinstance(input_shape, dict):
input_shape = input_shape['camera']
proprio_shape = input_shape['proprio']
self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)

self.num_seqs = kwargs.pop('num_seqs', 1)
self.value_size = kwargs.pop('value_size', 1)

NetworkBuilder.BaseNetwork.__init__(self)
Expand Down

0 comments on commit 7ea73fa

Please sign in to comment.