39 changes: 6 additions & 33 deletions Makefile
@@ -39,8 +39,8 @@ test-act-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=4 \
-	--online.steps=0 \
+	--steps=4 \
 	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_freq=2 \
@@ -76,8 +76,8 @@ test-diffusion-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=2 \
-	--online.steps=0 \
+	--steps=2 \
 	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_checkpoint=true \
@@ -106,8 +106,8 @@ test-tdmpc-ete-train:
 	--dataset.image_transforms.enable=true \
 	--dataset.episodes="[0]" \
 	--batch_size=2 \
-	--offline.steps=2 \
-	--online.steps=0 \
+	--steps=2 \
 	--eval_freq=2 \
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--save_checkpoint=true \
@@ -126,30 +126,3 @@ test-tdmpc-ete-eval:
 	--eval.n_episodes=1 \
 	--eval.batch_size=1 \
 	--device=$(DEVICE)
-
-# TODO(rcadene): fix online buffer to store "task"
-# test-tdmpc-ete-train-with-online:
-# python lerobot/scripts/train.py \
-# --policy.type=tdmpc \
-# --env.type=pusht \
-# --env.obs_type=environment_state_agent_pos \
-# --env.episode_length=5 \
-# --dataset.repo_id=lerobot/pusht_keypoints \
-# --dataset.image_transforms.enable=true \
-# --dataset.episodes="[0]" \
-# --batch_size=2 \
-# --offline.steps=2 \
-# --online.steps=20 \
-# --online.rollout_n_episodes=2 \
-# --online.rollout_batch_size=2 \
-# --online.steps_between_rollouts=10 \
-# --online.buffer_capacity=1000 \
-# --online.env_seed=10000 \
-# --save_checkpoint=false \
-# --save_freq=10 \
-# --log_freq=1 \
-# --eval.use_async_envs=true \
-# --eval.n_episodes=1 \
-# --eval.batch_size=1 \
-# --device=$(DEVICE) \
-# --output_dir=tests/outputs/tdmpc_online/
3 changes: 1 addition & 2 deletions examples/3_train_policy.py
@@ -86,8 +86,7 @@ def main():
     while not done:
         for batch in dataloader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-            output_dict = policy.forward(batch)
-            loss = output_dict["loss"]
+            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
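Read together with the matching hunk in `examples/advanced/2_calculate_validation_loss.py` below, this reflects an interface change: `policy.forward` now returns a `(loss, output_dict)` tuple rather than a dict carrying the loss under a `"loss"` key. A minimal sketch of the new call pattern, assuming the second element is a dict of auxiliary outputs (that detail is not spelled out in this diff):

```python
# Updated contract (sketch): forward() returns the scalar loss first.
loss, aux_outputs = policy.forward(batch)  # aux_outputs: assumed dict of extras, often unused
loss.backward()         # backpropagate on the scalar loss
optimizer.step()        # apply gradients
optimizer.zero_grad()   # reset for the next batch
```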
22 changes: 13 additions & 9 deletions examples/4_train_policy_with_script.md
@@ -161,26 +161,30 @@ python lerobot/scripts/train.py \
 ```
 You should see from the logging that your training picks up from where it left off.
 
-Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default.
+Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
 You could double the number of steps of the previous run with:
 ```bash
 python lerobot/scripts/train.py \
     --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
     --resume=true \
-    --offline.steps=200000
+    --steps=200000
 ```

 ## Outputs of a run
 In the output directory, there will be a folder called `checkpoints` with the following structure:
 ```bash
 outputs/train/run_resumption/checkpoints
 ├── 000100  # checkpoint_dir for training step 100
-│   ├── pretrained_model
-│   │   ├── config.json  # pretrained policy config
-│   │   ├── model.safetensors  # model weights
-│   │   ├── train_config.json  # train config
-│   │   └── README.md  # model card
-│   └── training_state.pth  # optimizer/scheduler/rng state and training step
+│   ├── pretrained_model/
+│   │   ├── config.json  # policy config
+│   │   ├── model.safetensors  # policy weights
+│   │   └── train_config.json  # train config
+│   └── training_state/
+│       ├── optimizer_param_groups.json  # optimizer param groups
+│       ├── optimizer_state.safetensors  # optimizer state
+│       ├── rng_state.safetensors  # rng states
+│       ├── scheduler_state.json  # scheduler state
+│       └── training_step.json  # training step
 ├── 000200
 └── last -> 000200  # symlink to the last available checkpoint
 ```
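Given this layout, everything needed for resumption lives under the `last` symlink. Below is a minimal sketch of reading the saved state back, assuming the file names shown above; the helper, and the `"step"` key inside `training_step.json`, are illustrative rather than code from this PR:

```python
import json
from pathlib import Path

from safetensors.torch import load_file  # the *.safetensors state files use this format

def read_training_step(checkpoint_dir: str) -> int:
    """Illustrative: recover the step stored in training_state/training_step.json."""
    step_file = Path(checkpoint_dir) / "training_state" / "training_step.json"
    with open(step_file) as f:
        return json.load(f)["step"]  # key name is an assumption

ckpt = "outputs/train/run_resumption/checkpoints/last"
step = read_training_step(ckpt)
optimizer_state = load_file(Path(ckpt) / "training_state" / "optimizer_state.safetensors")
```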
@@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
 python lerobot/scripts/train.py \
     --config_path=checkpoint/pretrained_model/ \
     --resume=true \
-    --offline.steps=200000 # <- you can change some training parameters
+    --steps=200000 # <- you can change some training parameters
 ```
 
 #### Fine-tuning
4 changes: 2 additions & 2 deletions examples/advanced/2_calculate_validation_loss.py
@@ -75,9 +75,9 @@ def main():
     n_examples_evaluated = 0
     for batch in val_dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        output_dict = policy.forward(batch)
+        loss, _ = policy.forward(batch)
 
-        loss_cumsum += output_dict["loss"].item()
+        loss_cumsum += loss.item()
         n_examples_evaluated += batch["index"].shape[0]
 
     # Calculate the average loss over the validation set.
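The trailing comment in this hunk points at the step that follows, which is presumably untouched by this PR; for completeness, the average works out as:

```python
# Continuation sketch under the names used above, not the file's exact code.
average_loss = loss_cumsum / n_examples_evaluated
print(f"Average validation loss: {average_loss:.4f}")
```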
11 changes: 11 additions & 0 deletions lerobot/common/constants.py
@@ -4,3 +4,14 @@
 OBS_IMAGE = "observation.image"
 OBS_IMAGES = "observation.images"
 ACTION = "action"
+
+# files & directories
+CHECKPOINTS_DIR = "checkpoints"
+LAST_CHECKPOINT_LINK = "last"
+PRETRAINED_MODEL_DIR = "pretrained_model"
+TRAINING_STATE_DIR = "training_state"
+RNG_STATE = "rng_state.safetensors"
+TRAINING_STEP = "training_step.json"
+OPTIMIZER_STATE = "optimizer_state.safetensors"
+OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
+SCHEDULER_STATE = "scheduler_state.json"
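These constants mirror the checkpoint tree documented in `examples/4_train_policy_with_script.md` above. A sketch of how they might be combined into checkpoint paths (the assembly is illustrative, not code from this PR):

```python
from pathlib import Path

from lerobot.common.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    PRETRAINED_MODEL_DIR,
    TRAINING_STATE_DIR,
    TRAINING_STEP,
)

output_dir = Path("outputs/train/run_resumption")
last_ckpt = output_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK  # "last" symlink
pretrained_dir = last_ckpt / PRETRAINED_MODEL_DIR                # config.json, model.safetensors, ...
step_file = last_ckpt / TRAINING_STATE_DIR / TRAINING_STEP       # .../training_state/training_step.json
```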
240 changes: 0 additions & 240 deletions lerobot/common/logger.py

This file was deleted.
