Behavior cloning bugfixes #565

Merged · 2 commits · Sep 19, 2024

Changes from all commits
95 changes: 66 additions & 29 deletions examples/baselines/bc/bc.py
@@ -1,8 +1,8 @@
from collections import defaultdict
import os
import random
from dataclasses import dataclass
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
@@ -12,16 +12,18 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import tyro
import wandb
from mani_skill.utils import gym_utils
from mani_skill.utils.io_utils import load_json
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from behavior_cloning.make_env import make_eval_envs

from behavior_cloning.evaluate import evaluate
from behavior_cloning.make_env import make_eval_envs


@dataclass
class Args:
exp_name: Optional[str] = None
@@ -43,7 +45,7 @@ class Args:

env_id: str = "PegInsertionSide-v0"
"""the id of the environment"""
demo_path: str = 'data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5'
demo_path: str = "data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5"
"""the path of demo dataset (pkl or h5)"""
num_demos: Optional[int] = None
"""number of trajectories to load from the demo dataset"""
@@ -76,12 +78,13 @@ class Args:
"""the simulation backend to use for evaluation environments. can be "cpu" or "gpu"""
num_dataload_workers: int = 0
"""the number of workers to use for loading the training data in the torch dataloader"""
control_mode: str = 'pd_joint_delta_pos'
control_mode: str = "pd_joint_delta_pos"
"""the control mode to use for the evaluation environments. Must match the control mode of the demonstration dataset."""

# additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms
demo_type: Optional[str] = None


# taken from here
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py
class IterationBasedBatchSampler(BatchSampler):
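The body of IterationBasedBatchSampler is collapsed in this diff. As a rough sketch of what such a sampler does, here is an adaptation of the NVIDIA reference linked in the comment above (the name and details are illustrative, not the exact code in bc.py):

```python
# Sketch adapted from the linked NVIDIA maskrcnn-benchmark sampler;
# the actual class body in bc.py is collapsed above and may differ slightly.
from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSamplerSketch(BatchSampler):
    """Wraps a BatchSampler and resamples from it until `num_iterations` batches have been yielded."""

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # cycle through the wrapped sampler as many epochs as needed
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
```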
@@ -203,10 +206,14 @@ def forward(self, state: torch.Tensor) -> torch.Tensor:


def save_ckpt(run_name, tag):
os.makedirs(f'runs/{run_name}/checkpoints', exist_ok=True)
torch.save({
"actor": actor.state_dict(),
}, f'runs/{run_name}/checkpoints/{tag}.pt')
os.makedirs(f"runs/{run_name}/checkpoints", exist_ok=True)
torch.save(
{
"actor": actor.state_dict(),
},
f"runs/{run_name}/checkpoints/{tag}.pt",
)


if __name__ == "__main__":
args = tyro.cli(Args)
@@ -217,18 +224,21 @@ def save_ckpt(run_name, tag):
else:
run_name = args.exp_name

if args.demo_path.endswith('.h5'):
if args.demo_path.endswith(".h5"):
import json
json_file = args.demo_path[:-2] + 'json'
with open(json_file, 'r') as f:

json_file = args.demo_path[:-2] + "json"
with open(json_file, "r") as f:
demo_info = json.load(f)
if 'control_mode' in demo_info['env_info']['env_kwargs']:
control_mode = demo_info['env_info']['env_kwargs']['control_mode']
elif 'control_mode' in demo_info['episodes'][0]:
control_mode = demo_info['episodes'][0]['control_mode']
if "control_mode" in demo_info["env_info"]["env_kwargs"]:
control_mode = demo_info["env_info"]["env_kwargs"]["control_mode"]
elif "control_mode" in demo_info["episodes"][0]:
control_mode = demo_info["episodes"][0]["control_mode"]
else:
raise Exception('Control mode not found in json')
assert control_mode == args.control_mode, f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"
raise Exception("Control mode not found in json")
assert (
control_mode == args.control_mode
), f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}"

np.random.seed(args.seed)
random.seed(args.seed)
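As a small illustration of the .h5 → .json path handling in this block (using the default demo_path from Args above; purely illustrative):

```python
# Illustration only: how the metadata path is derived from a .h5 demo path.
demo_path = "data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5"
json_file = demo_path[:-2] + "json"  # strip the trailing "h5", append "json"
print(json_file)
# data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.json
```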
@@ -238,15 +248,32 @@ def save_ckpt(run_name, tag):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
env_kwargs = dict(control_mode=args.control_mode, reward_mode="sparse", obs_mode="state", render_mode="rgb_array")
env_kwargs = dict(
control_mode=args.control_mode,
reward_mode="sparse",
obs_mode="state",
render_mode="rgb_array",
)
if args.max_episode_steps is not None:
env_kwargs["max_episode_steps"] = args.max_episode_steps
envs = make_eval_envs(args.env_id, args.num_eval_envs, args.sim_backend, env_kwargs, video_dir=f'runs/{run_name}/videos' if args.capture_video else None)
envs = make_eval_envs(
args.env_id,
args.num_eval_envs,
args.sim_backend,
env_kwargs,
video_dir=f"runs/{run_name}/videos" if args.capture_video else None,
)

if args.track:
import wandb

config = vars(args)
config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, env_horizon=gym_utils.find_max_episode_steps_value(envs))
config["eval_env_cfg"] = dict(
**env_kwargs,
num_envs=args.num_eval_envs,
env_id=args.env_id,
env_horizon=gym_utils.find_max_episode_steps_value(envs),
)
wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
@@ -255,12 +282,13 @@ def save_ckpt(run_name, tag):
name=run_name,
save_code=True,
group="BehaviorCloning",
tags=["behavior_cloning"]
tags=["behavior_cloning"],
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
"|param|value|\n|-|-|\n%s"
% ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

ds = ManiSkillDataset(
@@ -275,7 +303,9 @@ def save_ckpt(run_name, tag):
sampler = RandomSampler(ds)
batchsampler = BatchSampler(sampler, args.batch_size, drop_last=True)
itersampler = IterationBasedBatchSampler(batchsampler, args.total_iters)
dataloader = DataLoader(ds, batch_sampler=itersampler, num_workers=args.num_dataload_workers)
dataloader = DataLoader(
ds, batch_sampler=itersampler, num_workers=args.num_dataload_workers
)
actor = Actor(
envs.single_observation_space.shape[0], envs.single_action_space.shape[0]
)
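Because IterationBasedBatchSampler caps the stream at args.total_iters batches, one pass over this dataloader covers the whole training run; a minimal sketch of the (collapsed) loop shape, with batch handling elided since ManiSkillDataset's output format is not shown in this diff:

```python
# Sketch only: the actual training loop is collapsed in this diff.
for iteration, batch in enumerate(dataloader):
    # IterationBasedBatchSampler stops after args.total_iters batches,
    # re-shuffling and cycling through the dataset as many times as needed.
    ...
```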
@@ -296,18 +326,22 @@ def save_ckpt(run_name, tag):

if iteration % args.log_freq == 0:
print(f"Iteration {iteration}, loss: {loss.item()}")
writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], iteration)
writer.add_scalar(
"charts/learning_rate", optimizer.param_groups[0]["lr"], iteration
)
writer.add_scalar("losses/total_loss", loss.item(), iteration)

if iteration % args.eval_freq == 0:
actor.eval()

def sample_fn(obs):
if isinstance(obs, np.ndarray):
obs = torch.from_numpy(obs).float().to(device)
action = actor(obs)
if args.sim_backend == "cpu":
action = action.cpu().numpy()
return action

with torch.no_grad():
eval_metrics = evaluate(args.num_eval_episodes, sample_fn, envs)
actor.train()
@@ -322,9 +356,12 @@ def sample_fn(obs):
if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]:
best_eval_metrics[k] = eval_metrics[k]
save_ckpt(run_name, f"best_eval_{k}")
print(f'New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint.')
print(
f"New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint."
)

if args.save_freq is not None and iteration % args.save_freq == 0:
save_ckpt(run_name, str(iteration))
envs.close()
wandb.finish()
if args.track:
wandb.finish()
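Finally, a hedged sketch of loading the checkpoints that save_ckpt writes; the path placeholders and dimensions below are illustrative, and only the "actor" key is known from this diff:

```python
# Sketch only: loading a checkpoint produced by save_ckpt(run_name, tag) above.
# "<run_name>" and "<tag>" are placeholders, not values taken from this PR.
import torch

# Placeholders: match these to the env used for training.
obs_dim, act_dim = 42, 7
actor = Actor(obs_dim, act_dim)  # the Actor class defined earlier in bc.py
ckpt = torch.load("runs/<run_name>/checkpoints/<tag>.pt", map_location="cpu")
actor.load_state_dict(ckpt["actor"])  # save_ckpt stores the policy weights under "actor"
actor.eval()
```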