Commit 9e1edb3: "updates"
kellyguo11 committed Aug 3, 2024
1 parent 213d88f

Showing 3 changed files with 60 additions and 37 deletions.

File 1 of 3: Cartpole camera environment config (filename not shown on this page)
@@ -53,7 +53,7 @@ class CartpoleRGBCameraEnvCfg(DirectRLEnvCfg):
         height=80,
     )
     num_observations = num_channels * tiled_camera.height * tiled_camera.width
-    write_image_to_file = False
+    write_image_to_file = True

     # change viewer settings
     viewer = ViewerCfg(eye=(20.0, 20.0, 20.0))
@@ -172,14 +172,16 @@ def _apply_action(self) -> None:
     def _get_observations(self) -> dict:
         data_type = "rgb" if "rgb" in self.cfg.tiled_camera.data_types else "depth"
         if "rgb" in self.cfg.tiled_camera.data_types:
-            camera_data = 1 - self._tiled_camera.data.output[data_type]
+            camera_data = self._tiled_camera.data.output[data_type]
+            mean_tensor = torch.mean(camera_data, dim=(1, 2), keepdim=True)
+            camera_data -= mean_tensor
         elif "depth" in self.cfg.tiled_camera.data_types:
             camera_data = self._tiled_camera.data.output[data_type]
             camera_data[camera_data == float("inf")] = 0
         observations = {"policy": camera_data.clone()}

         if self.cfg.write_image_to_file:
-            save_images_to_file(observations["policy"], f"cartpole_{data_type}.png")
+            save_images_to_file(torch.clamp(observations["policy"], 0, 1), f"cartpole_{data_type}.png")

         return observations
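
Net effect of this hunk: the RGB observation is no longer color-inverted (1 - rgb); each image is instead zero-centered by subtracting its own spatial mean, and the debug-image path clamps back to [0, 1] before writing. A minimal standalone sketch of that pipeline, with a dummy tensor standing in for the tiled camera output (shape and values are illustrative only, not from this commit):

import torch

# Dummy stand-in for self._tiled_camera.data.output["rgb"]:
# (num_envs, height, width, channels), values in [0, 1].
camera_data = torch.rand(4, 80, 80, 3)

# Zero-center each image: subtract its own per-channel spatial mean.
mean_tensor = torch.mean(camera_data, dim=(1, 2), keepdim=True)  # (4, 1, 1, 3)
camera_data -= mean_tensor

# Mean-centered values can fall outside [0, 1], hence the clamp before
# save_images_to_file in the new code.
image_to_save = torch.clamp(camera_data, 0, 1)

print(camera_data.mean(dim=(1, 2)).abs().max().item())  # ~0.0: zero-mean per image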

File 2 of 3: rl_games agent configuration for the ShadowHand camera task (filename not shown on this page)
@@ -14,36 +14,41 @@ params:
   model:
     name: continuous_a2c_logstd

-  # doesn't have this fine grained control but made it close
   network:
     name: actor_critic
     separate: False

     space:
       continuous:
         mu_activation: None
         sigma_activation: None

         mu_init:
           name: default
         sigma_init:
           name: const_initializer
           val: 0
         fixed_sigma: True
     mlp:
-      units: [512, 512, 256, 128]
-      activation: elu
+      units: [512]
+      activation: relu
       d2rl: False

       initializer:
         name: default
       regularizer:
         name: None
+    rnn:
+      name: lstm
+      units: 1024
+      layers: 1
+      before_mlp: True
+      layer_norm: True

   load_checkpoint: False # flag which sets whether to load the checkpoint
   load_path: '' # path to the checkpoint to load

   config:
-    name: shadow_hand_resnet
+    name: shadow_hand_image_lstm
     env_name: rlgpu
     device: 'cuda:0'
     device_name: 'cuda:0'
@@ -52,29 +57,28 @@ params:
     mixed_precision: False
     normalize_input: True
     normalize_value: True
+    value_bootstrap: True
     num_actors: -1 # configured from the script (based on num_envs)
     reward_shaper:
       scale_value: 0.01
     normalize_advantage: True
-    gamma: 0.99
-    tau : 0.95
-    learning_rate: 5e-4
+    gamma: 0.998
+    tau: 0.95
+    learning_rate: 1e-4
     lr_schedule: adaptive
     schedule_type: standard
     kl_threshold: 0.016
     score_to_win: 100000
-    max_epochs: 50000
+    max_epochs: 100000
     save_best_after: 100
     save_frequency: 200
     print_stats: True
     grad_norm: 1.0
     entropy_coef: 0.0
     truncate_grads: True
     e_clip: 0.2
-    horizon_length: 32
-    minibatch_size: 8192
-    mini_epochs: 5
+    horizon_length: 64
+    minibatch_size: 32768
+    mini_epochs: 4
     critic_coef: 4
     clip_value: True
     seq_length: 4
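
These hyperparameters are coupled: in rl_games the rollout batch is horizon_length * num_actors, minibatch_size must divide it evenly, and with an RNN enabled horizon_length must be a multiple of seq_length. A quick sanity check of the new values, assuming num_actors resolves to the 512 environments set in the env config below (that mapping is my assumption; the YAML leaves num_actors at -1):

# Assumed: num_envs comes from the env config in this commit (scene num_envs=512).
num_envs = 512
horizon_length = 64
minibatch_size = 32768
seq_length = 4

batch_size = horizon_length * num_envs      # 32768 transitions per rollout
assert batch_size % minibatch_size == 0     # exactly 1 minibatch per mini-epoch
assert horizon_length % seq_length == 0     # 16 LSTM sequences per env per rollout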
File 3 of 3: ShadowHand camera environment (filename not shown on this page)
@@ -30,7 +30,7 @@
 @configclass
 class ShadowHandRGBCameraEnvCfg(ShadowHandEnvCfg):
     # scene
-    scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=128, env_spacing=5.0, replicate_physics=True)
+    scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=512, env_spacing=5.0, replicate_physics=True)

     # camera
     tiled_camera: TiledCameraCfg = TiledCameraCfg(
@@ -50,7 +50,7 @@ class ShadowHandRGBCameraEnvCfg(ShadowHandEnvCfg):

     # env
     num_channels = 3
-    num_observations = 157 - 17 + 128 + 128  # 649 # 536 # num_channels * tiled_camera.height * tiled_camera.width # + 157
+    num_observations = 157 - 17 + 512  # 649 # 536 # num_channels * tiled_camera.height * tiled_camera.width # + 157


 @configclass
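
The observation-size arithmetic is easier to read with labels. The split below is an assumption (the diff gives only the numbers), with 512 matching the output width of the CustomCNN head defined later in this commit:

# Hypothetical breakdown -- only the numbers appear in the diff.
state_dims = 157 - 17        # original state observations minus the dropped terms
image_embedding = 512        # CustomCNN output width (nn.Linear(256, 512) below)
num_observations = state_dims + image_embedding  # = 652
# Previously: 157 - 17 + 128 + 128 = 396 (separate RGB and depth embeddings).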
@@ -203,55 +203,72 @@ def forward(self, x):
         return x

 class CustomCNN(nn.Module):
-    def __init__(self, depth=False):
+    def __init__(self, device, depth=False):
+        self.device = device
         super().__init__()
         # We assume CxHxW images (channels first)
         # Re-ordering will be done by pre-preprocessing or wrapper
         num_channel = 1 if depth else 3
         self.cnn = nn.Sequential(
             nn.Conv2d(num_channel, 16, kernel_size=6, stride=2, padding=0),
             nn.ReLU(),
-            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=0),
+            # nn.BatchNorm2d(16),
+            nn.LayerNorm([16, 110, 110]),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),
             nn.ReLU(),
+            # nn.BatchNorm2d(32),
+            nn.LayerNorm([32, 54, 54]),
             nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=0),
+            # nn.BatchNorm2d(64),
+            nn.LayerNorm([64, 26, 26]),
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
             nn.ReLU(),
+            # nn.BatchNorm2d(128),
+            nn.LayerNorm([128, 12, 12]),
+            nn.AvgPool2d(12)
         )

         self.linear = nn.Sequential(
-            nn.Linear(64*22*22, 1024),
+            nn.Linear(128, 256),
             nn.ReLU(),
-            nn.Linear(1024, 128),
+            nn.Linear(256, 512),
             nn.ReLU(),
         )

+        # Standard ImageNet channel statistics (torchvision ResNet normalization)
+        self.resnet18_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device)
+        self.resnet18_std = torch.tensor([0.229, 0.224, 0.225], device=self.device)
+        self.resnet_transform = transforms.Normalize(self.resnet18_mean, self.resnet18_std)

     def forward(self, x):
         # save_images_to_file(x, f"shadow_hand_transformed.png")
-        cnn_x = self.cnn(x.permute(0, 3, 1, 2))
-        out = self.linear(cnn_x.reshape(-1, 64*22*22))
+        cnn_x = self.cnn(self.resnet_transform(x.permute(0, 3, 1, 2)))
+        out = self.linear(cnn_x.view(-1, 128))
         return out

 # model = ResNet18()
-self.rgb_model = CustomCNN()
-self.depth_model = CustomCNN(depth=True)
+self.rgb_model = CustomCNN(self.device)
+# self.depth_model = CustomCNN(depth=True)
 self.rgb_model.to(self.device)
-self.depth_model.to(self.device)
+# self.depth_model.to(self.device)


 def compute_embeddings_observations(self, state_obs):
-    rgb_img = 1 - self._tiled_camera.data.output["rgb"][..., :3].clone()
-    depth_img = self._tiled_camera.data.output["depth"].clone()
-    depth_img[depth_img == float("inf")] = 0
-    depth_img /= 5.0
-    depth_img /= torch.max(depth_img)
+    rgb_img = self._tiled_camera.data.output["rgb"][..., :3].clone()


+    # mean_tensor = torch.mean(rgb_img, dim=(1, 2), keepdim=True)
+    # rgb_img -= mean_tensor
+    # depth_img = self._tiled_camera.data.output["depth"].clone()
+    # depth_img[depth_img == float("inf")] = 0
+    # depth_img /= 5.0
+    # depth_img /= torch.max(depth_img)
     rgb_embeddings = self.rgb_model(rgb_img)
-    depth_embeddings = self.depth_model(depth_img)
+    # depth_embeddings = self.depth_model(depth_img)

     obs = torch.cat(
         (
             state_obs,
             rgb_embeddings,
-            depth_embeddings
+            # depth_embeddings
         ),
         dim=-1
     )
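
The hard-coded LayerNorm shapes in CustomCNN only line up for a 224x224 input (an inference from the shapes; the camera resolution is outside the visible diff, and the old 64*22*22 flatten implies the same input size). A minimal check of the conv arithmetic and of the final feature width feeding nn.Linear(128, 256):

import torch
import torch.nn as nn
from torchvision import transforms

def conv_out(size, kernel, stride):
    return (size - kernel) // stride + 1  # unpadded convolution

size = 224
for k, s in [(6, 2), (4, 2), (4, 2), (4, 2)]:
    size = conv_out(size, k, s)  # 110 -> 54 -> 26 -> 12, matching the LayerNorm shapes
print(size)  # 12; AvgPool2d(12) then reduces the map to 1x1

cnn = nn.Sequential(
    nn.Conv2d(3, 16, 6, 2), nn.ReLU(), nn.LayerNorm([16, 110, 110]),
    nn.Conv2d(16, 32, 4, 2), nn.ReLU(), nn.LayerNorm([32, 54, 54]),
    nn.Conv2d(32, 64, 4, 2), nn.ReLU(), nn.LayerNorm([64, 26, 26]),
    nn.Conv2d(64, 128, 4, 2), nn.ReLU(), nn.LayerNorm([128, 12, 12]),
    nn.AvgPool2d(12),
)
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
x = torch.rand(2, 224, 224, 3)                # NHWC, as the tiled camera returns
y = cnn(normalize(x.permute(0, 3, 1, 2)))     # NCHW in, ImageNet-normalized
print(y.shape)                                # torch.Size([2, 128, 1, 1]) -> view(-1, 128)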
