diff --git a/examples/baselines/bc/bc.py b/examples/baselines/bc/bc.py
index 84b7bc30a..5d5b86c8c 100644
--- a/examples/baselines/bc/bc.py
+++ b/examples/baselines/bc/bc.py
@@ -1,8 +1,8 @@
-from collections import defaultdict
 import os
 import random
-from dataclasses import dataclass
 import time
+from collections import defaultdict
+from dataclasses import dataclass
 from typing import Optional
 
 import gymnasium as gym
@@ -12,16 +12,18 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
 import tyro
-import wandb
 from mani_skill.utils import gym_utils
 from mani_skill.utils.io_utils import load_json
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.sampler import BatchSampler, RandomSampler
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from behavior_cloning.make_env import make_eval_envs
+
 from behavior_cloning.evaluate import evaluate
+from behavior_cloning.make_env import make_eval_envs
+
+
 @dataclass
 class Args:
     exp_name: Optional[str] = None
@@ -43,7 +45,7 @@ class Args:
 
     env_id: str = "PegInsertionSide-v0"
     """the id of the environment"""
-    demo_path: str = 'data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5'
+    demo_path: str = "data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5"
     """the path of demo dataset (pkl or h5)"""
     num_demos: Optional[int] = None
     """number of trajectories to load from the demo dataset"""
@@ -76,12 +78,13 @@ class Args:
     """the simulation backend to use for evaluation environments. can be "cpu" or "gpu"""
     num_dataload_workers: int = 0
    """the number of workers to use for loading the training data in the torch dataloader"""
-    control_mode: str = 'pd_joint_delta_pos'
+    control_mode: str = "pd_joint_delta_pos"
     """the control mode to use for the evaluation environments. Must match the control mode of the demonstration dataset."""
 
     # additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms
     demo_type: Optional[str] = None
+
 
 # taken from here
 # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py
 class IterationBasedBatchSampler(BatchSampler):
@@ -203,10 +206,14 @@ def forward(self, state: torch.Tensor) -> torch.Tensor:
 
 
 def save_ckpt(run_name, tag):
-    os.makedirs(f'runs/{run_name}/checkpoints', exist_ok=True)
-    torch.save({
-        "actor": actor.state_dict(),
-    }, f'runs/{run_name}/checkpoints/{tag}.pt')
+    os.makedirs(f"runs/{run_name}/checkpoints", exist_ok=True)
+    torch.save(
+        {
+            "actor": actor.state_dict(),
+        },
+        f"runs/{run_name}/checkpoints/{tag}.pt",
+    )
+
 
 if __name__ == "__main__":
     args = tyro.cli(Args)
@@ -217,18 +224,21 @@ def save_ckpt(run_name, tag):
     else:
         run_name = args.exp_name
 
-    if args.demo_path.endswith('.h5'):
+    if args.demo_path.endswith(".h5"):
         import json
-        json_file = args.demo_path[:-2] + 'json'
-        with open(json_file, 'r') as f:
+
+        json_file = args.demo_path[:-2] + "json"
+        with open(json_file, "r") as f:
             demo_info = json.load(f)
-        if 'control_mode' in demo_info['env_info']['env_kwargs']:
-            control_mode = demo_info['env_info']['env_kwargs']['control_mode']
-        elif 'control_mode' in demo_info['episodes'][0]:
-            control_mode = demo_info['episodes'][0]['control_mode']
+        if "control_mode" in demo_info["env_info"]["env_kwargs"]:
+            control_mode = demo_info["env_info"]["env_kwargs"]["control_mode"]
+        elif "control_mode" in demo_info["episodes"][0]:
+            control_mode = demo_info["episodes"][0]["control_mode"]
         else:
-            raise Exception('Control mode not found in json')
-        assert control_mode == args.control_mode, f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"
+            raise Exception("Control mode not found in json")
+        assert (
+            control_mode == args.control_mode
+        ), f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"
 
     np.random.seed(args.seed)
     random.seed(args.seed)
@@ -238,15 +248,32 @@ def save_ckpt(run_name, tag):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
 
     # env setup
-    env_kwargs = dict(control_mode=args.control_mode, reward_mode="sparse", obs_mode="state", render_mode="rgb_array")
+    env_kwargs = dict(
+        control_mode=args.control_mode,
+        reward_mode="sparse",
+        obs_mode="state",
+        render_mode="rgb_array",
+    )
     if args.max_episode_steps is not None:
         env_kwargs["max_episode_steps"] = args.max_episode_steps
-    envs = make_eval_envs(args.env_id, args.num_eval_envs, args.sim_backend, env_kwargs, video_dir=f'runs/{run_name}/videos' if args.capture_video else None)
+    envs = make_eval_envs(
+        args.env_id,
+        args.num_eval_envs,
+        args.sim_backend,
+        env_kwargs,
+        video_dir=f"runs/{run_name}/videos" if args.capture_video else None,
+    )
 
     if args.track:
         import wandb
+
         config = vars(args)
-        config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, env_horizon=gym_utils.find_max_episode_steps_value(envs))
+        config["eval_env_cfg"] = dict(
+            **env_kwargs,
+            num_envs=args.num_eval_envs,
+            env_id=args.env_id,
+            env_horizon=gym_utils.find_max_episode_steps_value(envs),
+        )
         wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
@@ -255,12 +282,13 @@ def save_ckpt(run_name, tag):
             name=run_name,
             save_code=True,
             group="BehaviorCloning",
-            tags=["behavior_cloning"]
+            tags=["behavior_cloning"],
         )
     writer = SummaryWriter(f"runs/{run_name}")
     writer.add_text(
         "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        "|param|value|\n|-|-|\n%s"
+        % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
     ds = ManiSkillDataset(
@@ -275,7 +303,9 @@ def save_ckpt(run_name, tag):
     sampler = RandomSampler(ds)
     batchsampler = BatchSampler(sampler, args.batch_size, drop_last=True)
     itersampler = IterationBasedBatchSampler(batchsampler, args.total_iters)
-    dataloader = DataLoader(ds, batch_sampler=itersampler, num_workers=args.num_dataload_workers)
+    dataloader = DataLoader(
+        ds, batch_sampler=itersampler, num_workers=args.num_dataload_workers
+    )
     actor = Actor(
         envs.single_observation_space.shape[0], envs.single_action_space.shape[0]
     )
@@ -296,11 +326,14 @@ def save_ckpt(run_name, tag):
 
         if iteration % args.log_freq == 0:
             print(f"Iteration {iteration}, loss: {loss.item()}")
-            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], iteration)
+            writer.add_scalar(
+                "charts/learning_rate", optimizer.param_groups[0]["lr"], iteration
+            )
             writer.add_scalar("losses/total_loss", loss.item(), iteration)
 
         if iteration % args.eval_freq == 0:
             actor.eval()
+
             def sample_fn(obs):
                 if isinstance(obs, np.ndarray):
                     obs = torch.from_numpy(obs).float().to(device)
@@ -308,6 +341,7 @@ def sample_fn(obs):
                 if args.sim_backend == "cpu":
                     action = action.cpu().numpy()
                 return action
+
             with torch.no_grad():
                 eval_metrics = evaluate(args.num_eval_episodes, sample_fn, envs)
             actor.train()
@@ -322,9 +356,12 @@ def sample_fn(obs):
                 if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]:
                     best_eval_metrics[k] = eval_metrics[k]
                     save_ckpt(run_name, f"best_eval_{k}")
-                    print(f'New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint.')
+                    print(
+                        f"New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint."
+                    )
 
         if args.save_freq is not None and iteration % args.save_freq == 0:
             save_ckpt(run_name, str(iteration))
     envs.close()
-    wandb.finish()
+    if args.track:
+        wandb.finish()
diff --git a/examples/baselines/bc/bc_rgbd.py b/examples/baselines/bc/bc_rgbd.py
index 7a4bda841..9425273d3 100644
--- a/examples/baselines/bc/bc_rgbd.py
+++ b/examples/baselines/bc/bc_rgbd.py
@@ -1,8 +1,8 @@
-from collections import defaultdict
 import os
-from dataclasses import dataclass
 import random
 import time
+from collections import defaultdict
+from dataclasses import dataclass
 from typing import Optional
 
 import gymnasium as gym
@@ -13,15 +13,17 @@
 import torch.nn.functional as F
 import torch.optim as optim
 import tyro
-import wandb
-from torch.utils.tensorboard import SummaryWriter
+from mani_skill.utils import gym_utils
 from mani_skill.utils.io_utils import load_json
 from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.sampler import BatchSampler, RandomSampler
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from behavior_cloning.make_env import make_eval_envs
+
 from behavior_cloning.evaluate import evaluate
+from behavior_cloning.make_env import make_eval_envs
+
 
 @dataclass
 class Args:
@@ -44,7 +46,7 @@ class Args:
 
     env_id: str = "PegInsertionSide-v0"
     """the id of the environment"""
-    demo_path: str = 'data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5'
+    demo_path: str = "data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5"
     """the path of demo dataset (pkl or h5)"""
     num_demos: Optional[int] = None
     """number of trajectories to load from the demo dataset"""
@@ -79,14 +81,13 @@ class Args:
     """the simulation backend to use for evaluation environments. can be "cpu" or "gpu"""
     num_dataload_workers: int = 0
     """the number of workers to use for loading the training data in the torch dataloader"""
-    control_mode: str = 'pd_joint_delta_pos'
+    control_mode: str = "pd_joint_delta_pos"
     """the control mode to use for the evaluation environments. Must match the control mode of the demonstration dataset."""
 
     # additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms
     demo_type: Optional[str] = None
-
 
 
 def load_h5_data(data):
     out = dict()
     for k in data.keys():
@@ -107,6 +108,7 @@ def make_mlp(in_channels, mlp_channels, act_builder=nn.ReLU, last_act=True):
         c_in = c_out
     return nn.Sequential(*module_list)
 
+
 def flatten_state_dict_with_space(state_dict: dict) -> np.ndarray:
     states = []
     for key in state_dict.keys():
@@ -165,10 +167,7 @@ def __init__(self, dataset_file: str, device: torch.device, load_count) -> None:
 
         self.total_frames = 0
         self.device = device
 
-        if load_count > len(self.episodes):
-            print(
-                f"Load count exceeds number of available episodes, loading {len(self.episodes)} which is the max number of episodes present"
-            )
+        if load_count is None:
             load_count = len(self.episodes)
         for eps_id in tqdm(range(load_count)):
@@ -320,10 +319,13 @@ def forward(self, rgbd, state):
 
 
 def save_ckpt(run_name, tag):
-    os.makedirs(f'runs/{run_name}/checkpoints', exist_ok=True)
-    torch.save({
-        "actor": actor.state_dict(),
-    }, f'runs/{run_name}/checkpoints/{tag}.pt')
+    os.makedirs(f"runs/{run_name}/checkpoints", exist_ok=True)
+    torch.save(
+        {
+            "actor": actor.state_dict(),
+        },
+        f"runs/{run_name}/checkpoints/{tag}.pt",
+    )
 
 
 if __name__ == "__main__":
@@ -335,18 +337,21 @@ def save_ckpt(run_name, tag):
     else:
         run_name = args.exp_name
 
-    if args.demo_path.endswith('.h5'):
+    if args.demo_path.endswith(".h5"):
         import json
-        json_file = args.demo_path[:-2] + 'json'
-        with open(json_file, 'r') as f:
+
+        json_file = args.demo_path[:-2] + "json"
+        with open(json_file, "r") as f:
             demo_info = json.load(f)
-        if 'control_mode' in demo_info['env_info']['env_kwargs']:
-            control_mode = demo_info['env_info']['env_kwargs']['control_mode']
-        elif 'control_mode' in demo_info['episodes'][0]:
-            control_mode = demo_info['episodes'][0]['control_mode']
+        if "control_mode" in demo_info["env_info"]["env_kwargs"]:
+            control_mode = demo_info["env_info"]["env_kwargs"]["control_mode"]
+        elif "control_mode" in demo_info["episodes"][0]:
+            control_mode = demo_info["episodes"][0]["control_mode"]
         else:
-            raise Exception('Control mode not found in json')
-        assert control_mode == args.control_mode, f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"
+            raise Exception("Control mode not found in json")
+        assert (
+            control_mode == args.control_mode
+        ), f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"
 
     np.random.seed(args.seed)
     random.seed(args.seed)
@@ -357,15 +362,33 @@ def save_ckpt(run_name, tag):
     control_mode = os.path.split(args.demo_path)[1].split(".")[2]
 
     # env setup
-    env_kwargs = dict(control_mode=args.control_mode, reward_mode="sparse", obs_mode="rgbd", render_mode="all")
+    env_kwargs = dict(
+        control_mode=args.control_mode,
+        reward_mode="sparse",
+        obs_mode="rgbd",
+        render_mode="all",
+    )
     if args.max_episode_steps is not None:
         env_kwargs["max_episode_steps"] = args.max_episode_steps
-    envs = make_eval_envs(args.env_id, args.num_eval_envs, args.sim_backend, env_kwargs, video_dir=f'runs/{run_name}/videos' if args.capture_video else None, wrappers=[FlattenRGBDObservationWrapper])
+    envs = make_eval_envs(
+        args.env_id,
+        args.num_eval_envs,
+        args.sim_backend,
+        env_kwargs,
+        video_dir=f"runs/{run_name}/videos" if args.capture_video else None,
+        wrappers=[FlattenRGBDObservationWrapper],
+    )
 
     if args.track:
         import wandb
+
         config = vars(args)
-        config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, env_horizon=gym_utils.find_max_episode_steps_value(envs))
+        config["eval_env_cfg"] = dict(
+            **env_kwargs,
+            num_envs=args.num_eval_envs,
+            env_id=args.env_id,
+            env_horizon=gym_utils.find_max_episode_steps_value(envs),
+        )
         wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
@@ -374,12 +397,13 @@ def save_ckpt(run_name, tag):
             name=run_name,
             save_code=True,
             group="BehaviorCloning",
-            tags=["behavior_cloning"]
+            tags=["behavior_cloning"],
         )
     writer = SummaryWriter(f"runs/{run_name}")
     writer.add_text(
         "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        "|param|value|\n|-|-|\n%s"
+        % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
     ds = ManiSkillDataset(
@@ -392,7 +416,7 @@ def save_ckpt(run_name, tag):
 
     sampler = RandomSampler(ds)
     batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
-    iter_sampler = IterationBasedBatchSampler(batch_sampler, args.max_timesteps)
+    iter_sampler = IterationBasedBatchSampler(batch_sampler, args.total_iters)
     data_loader = DataLoader(ds, batch_sampler=iter_sampler, num_workers=0)
 
     actor = Actor(ds.states.shape[1], envs.single_action_space.shape[0]).to(
@@ -413,18 +437,26 @@ def save_ckpt(run_name, tag):
 
         if iteration % args.log_freq == 0:
             print(f"Iteration {iteration}, loss: {loss.item()}")
-            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], iteration)
+            writer.add_scalar(
+                "charts/learning_rate", optimizer.param_groups[0]["lr"], iteration
+            )
             writer.add_scalar("losses/total_loss", loss.item(), iteration)
 
         if iteration % args.eval_freq == 0:
             actor.eval()
+            norm_tensor = torch.Tensor([255.0, 255.0, 255.0, 1024.0]).to(device)
+
             def sample_fn(obs):
-                if isinstance(obs, np.ndarray):
-                    obs = torch.from_numpy(obs).float().to(device)
-                action = actor(obs)
+                if isinstance(obs["rgbd"], np.ndarray):
+                    for k, v in obs.items():
+                        obs[k] = torch.from_numpy(v).float().to(device)
+
+                obs["rgbd"] = torch.div(obs["rgbd"], norm_tensor)
+                action = actor(obs["rgbd"], obs["state"])
                 if args.sim_backend == "cpu":
                     action = action.cpu().numpy()
                 return action
+
             with torch.no_grad():
                 eval_metrics = evaluate(args.num_eval_episodes, sample_fn, envs)
             actor.train()
@@ -439,9 +471,12 @@ def sample_fn(obs):
                 if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]:
                     best_eval_metrics[k] = eval_metrics[k]
                     save_ckpt(run_name, f"best_eval_{k}")
-                    print(f'New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint.')
+                    print(
+                        f"New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint."
+                    )
 
         if args.save_freq is not None and iteration % args.save_freq == 0:
             save_ckpt(run_name, str(iteration))
     envs.close()
-    wandb.finish()
+    if args.track:
+        wandb.finish()
diff --git a/examples/baselines/bc/examples.sh b/examples/baselines/bc/examples.sh
index b25992528..4e58ba24c 100644
--- a/examples/baselines/bc/examples.sh
+++ b/examples/baselines/bc/examples.sh
@@ -36,4 +36,4 @@ python -m mani_skill.trajectory.replay_trajectory \
 python bc.py --env-id "PickCube-v1" \
   --demo-path ~/.maniskill/demos/PickCube-v1/rl/trajectory.state.pd_joint_delta_pos.cpu.h5 \
   --control-mode "pd_joint_delta_pos" --sim-backend "cpu" --max-episode-steps 100 \
-  --total-iters 10000
\ No newline at end of file
+  --total-iters 10000
diff --git a/mani_skill/utils/wrappers/gymnasium.py b/mani_skill/utils/wrappers/gymnasium.py
index 6d682b71e..17c4aca30 100644
--- a/mani_skill/utils/wrappers/gymnasium.py
+++ b/mani_skill/utils/wrappers/gymnasium.py
@@ -98,5 +98,5 @@ def reset(self, *, seed=None, options=None):
 
     def render(self):
         ret = self.env.render()
-        if self.render_mode in ["rgb_array", "sensors"]:
+        if self.render_mode in ["rgb_array", "sensors", "all"]:
             return common.unbatch(common.to_numpy(ret))