6 changes: 3 additions & 3 deletions lerobot/common/logger.py
@@ -27,7 +27,7 @@
import draccus
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -121,7 +121,7 @@ def __init__(self, cfg: TrainPipelineConfig):
notes=cfg.wandb.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
-config=OmegaConf.to_container(cfg, resolve=True),
+config=draccus.encode(cfg),
Collaborator Author: TODO: remove

# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
@@ -228,7 +228,7 @@ def log_dict(self, d: dict, step: int, mode: str = "train"):
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
-wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
+wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)


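Side note on the `config=draccus.encode(cfg)` swap: `draccus.encode` turns a (possibly nested) dataclass config into plain JSON-serializable containers, which is what `wandb.init(config=...)` expects, much like `OmegaConf.to_container` did for the previous `DictConfig`-based config. A minimal sketch with made-up stand-in dataclasses (not the real lerobot config classes), assuming draccus's standard encode behaviour:

```python
from dataclasses import dataclass, field

import draccus


@dataclass
class WandBConfig:
    # Hypothetical stand-in for the wandb sub-config; fields are illustrative.
    enable: bool = False
    project: str = "lerobot"
    notes: str | None = None


@dataclass
class DemoTrainConfig:
    # Hypothetical stand-in for TrainPipelineConfig.
    seed: int = 1000
    batch_size: int = 8
    wandb: WandBConfig = field(default_factory=WandBConfig)


cfg = DemoTrainConfig()
# draccus.encode converts the nested dataclass into plain JSON-serializable
# containers, suitable for passing to wandb.init(config=...).
print(draccus.encode(cfg))
# expected: {'seed': 1000, 'batch_size': 8,
#            'wandb': {'enable': False, 'project': 'lerobot', 'notes': None}}
```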
2 changes: 1 addition & 1 deletion lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -98,7 +98,7 @@ class VQBeTConfig(PretrainedConfig):

normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"VISUAL": None,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
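On the `"VISUAL": None` entry: the assumption is that a `None` mode means visual features are passed through without normalization, whereas the previous value applied mean/std scaling. A toy, hypothetical sketch of how such a mapping could be consumed; this is not lerobot's actual normalization code, and the stats format here is invented for illustration:

```python
import torch

# Hypothetical stand-in for the normalization logic; lerobot's actual Normalize
# module may differ. The assumption: a None mode means "pass the feature through".
normalization_mapping = {"VISUAL": None, "STATE": "MIN_MAX", "ACTION": "MIN_MAX"}


def normalize(key: str, x: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    mode = normalization_mapping.get(key)
    if mode is None:
        return x  # e.g. VISUAL: images are consumed as-is, no mean/std scaling
    if mode == "MIN_MAX":
        return (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
    if mode == "MEAN_STD":
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    raise ValueError(f"Unknown normalization mode: {mode}")


stats = {"min": torch.zeros(2), "max": 10 * torch.ones(2)}
print(normalize("STATE", torch.tensor([2.0, 5.0]), stats))      # scaled into [0, 1]
print(normalize("VISUAL", torch.rand(3, 64, 64), stats).shape)  # returned unchanged
```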
13 changes: 13 additions & 0 deletions lerobot/configs/training.py
@@ -61,6 +61,19 @@ class OnlineConfig:
# + eval + environment rendering simultaneously.
do_rollout_async: bool = False

def __post_init__(self):
if self.steps == 0:
return

if self.steps_between_rollouts is None:
raise ValueError(
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
)
if self.env_seed is None:
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
if self.buffer_capacity is None:
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")


@dataclass
class TrainPipelineConfig:
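The new `__post_init__` in `OnlineConfig` makes invalid online settings fail at config construction time rather than deep inside a training run. A minimal sketch of the same pattern with a toy dataclass (`ToyOnlineConfig` is illustrative, not the real class):

```python
from dataclasses import dataclass


@dataclass
class ToyOnlineConfig:
    # Toy stand-in for OnlineConfig, showing the same validation pattern:
    # optional fields may stay None only while online training is disabled.
    steps: int = 0
    buffer_capacity: int | None = None

    def __post_init__(self):
        if self.steps == 0:
            return  # purely offline run, nothing to check
        if self.buffer_capacity is None:
            raise ValueError("'buffer_capacity' must be set when online steps > 0.")


ToyOnlineConfig()                                # ok: offline only
ToyOnlineConfig(steps=100, buffer_capacity=256)  # ok: online fields provided
# ToyOnlineConfig(steps=100)                     # would raise ValueError at construction
```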
51 changes: 34 additions & 17 deletions lerobot/scripts/train.py
@@ -363,7 +363,9 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"next.success": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# Removed next.success, since it's not used anywhere for now and offline dataset doesnt have it
# "next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.online.buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
@@ -400,12 +402,14 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
)
dl_iter = cycle(dataloader)

-# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
-# these are still used but effectively do nothing.
-lock = Lock()
-# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
-# parallelization of rollouts is handled within the job.
-executor = ThreadPoolExecutor(max_workers=1)
+if cfg.online.do_rollout_async:
+    # Lock and thread pool executor for asynchronous online rollouts.
+    lock = Lock()
+    # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
+    # parallelization of rollouts is handled within the job.
+    executor = ThreadPoolExecutor(max_workers=1)
+else:
+    lock = None

online_step = 0
online_rollout_s = 0 # time taken to do online rollout
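The hunks below repeatedly use `with lock if lock is not None else nullcontext():` so the same critical sections work whether or not async rollouts (and hence a real `Lock`) exist. A self-contained sketch of that pattern with invented names, not lerobot's actual state; `nullcontext` comes from `contextlib`:

```python
from contextlib import nullcontext
from threading import Lock

do_rollout_async = False  # stand-in for cfg.online.do_rollout_async

# Only create a real lock when rollouts run in a background thread.
lock = Lock() if do_rollout_async else None

shared_steps = 0


def update_shared_state():
    global shared_steps
    # nullcontext() is a no-op context manager, so the same `with` statement
    # works in both the synchronous and the asynchronous case.
    with lock if lock is not None else nullcontext():
        shared_steps += 1


update_shared_state()
print(shared_steps)  # 1
```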
@@ -424,10 +428,13 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):

def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
-with lock:
+
+with lock if lock is not None else nullcontext():
online_rollout_policy.load_state_dict(policy.state_dict())

online_rollout_policy.eval()
start_rollout_time = time.perf_counter()

with torch.no_grad():
eval_info = eval_policy(
online_env,
@@ -440,7 +447,14 @@ def sample_trajectory_and_update_buffer():
)
online_rollout_s = time.perf_counter() - start_rollout_time

-with lock:
+if len(offline_dataset.meta.tasks) > 1:
+    raise NotImplementedError("Add support for multi task.")
+
+# Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
+total_num_frames = eval_info["episodes"]["index"].shape[0]
+eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
+
+with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])

@@ … @@

return online_rollout_s, update_online_buffer_s

-future = executor.submit(sample_trajectory_and_update_buffer)
-# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
-# here until the rollout and buffer update is done, before proceeding to the policy update steps.
-if not cfg.online.do_rollout_async or len(online_dataset) <= cfg.online.buffer_seed_size:
-    online_rollout_s, update_online_buffer_s = future.result()
+if lock is None:
+    online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
+else:
+    future = executor.submit(sample_trajectory_and_update_buffer)
+    # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
+    # here until the rollout and buffer update is done, before proceeding to the policy update steps.
+    if len(online_dataset) <= cfg.online.buffer_seed_size:
+        online_rollout_s, update_online_buffer_s = future.result()

if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue

policy.train()
for _ in range(cfg.online.steps_between_rollouts):
-with lock:
+with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
@@ -498,7 +515,7 @@ def sample_trajectory_and_update_buffer():
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-with lock:
+with lock if lock is not None else nullcontext():
train_info["online_buffer_size"] = len(online_dataset)

if step % cfg.log_freq == 0:
@@ … @@

# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
-if future.running():
+if cfg.online.do_rollout_async and future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start
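To summarize the sync/async split introduced in this file: when async rollouts are off, the rollout runs inline and blocks; otherwise it is submitted to the single-worker `ThreadPoolExecutor` and its `Future` is collected before the next rollout starts. A rough, self-contained sketch of that dispatch (toy `rollout_job`, not lerobot's actual function):

```python
import time
from concurrent.futures import ThreadPoolExecutor


def rollout_job() -> float:
    # Toy stand-in for sample_trajectory_and_update_buffer(): do some "work"
    # and report how long it took.
    start = time.perf_counter()
    time.sleep(0.1)
    return time.perf_counter() - start


do_rollout_async = True  # stand-in for cfg.online.do_rollout_async
executor = ThreadPoolExecutor(max_workers=1) if do_rollout_async else None

if executor is None:
    # Synchronous mode: run the rollout inline and block until it finishes.
    rollout_s = rollout_job()
else:
    # Asynchronous mode: hand the rollout to the single worker thread, keep
    # doing policy updates, then collect the result before the next rollout.
    future = executor.submit(rollout_job)
    # ... policy update steps would run here ...
    rollout_s = future.result()
    executor.shutdown()

print(f"rollout took {rollout_s:.3f}s")
```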