
Commit 8b94972

Merge remote-tracking branch 'origin/user/aliberts/2024_09_25_reshape_dataset' into user/rcadene/2024_11_01_examples_port_datasets
2 parents da67242 + c72ad49 commit 8b94972

14 files changed: 309 additions, 145 deletions

README.md (+5, -5)
@@ -55,9 +55,9 @@
 
 <table>
 <tr>
-<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
-<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
-<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
+<td><img src="media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
+<td><img src="media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
+<td><img src="media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
 </tr>
 <tr>
 <td align="center">ACT policy on ALOHA env</td>
@@ -144,7 +144,7 @@ wandb login
 
 ### Visualize datasets
 
-Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
+Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
 
 You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
 ```bash
@@ -280,7 +280,7 @@ To use wandb for logging training and evaluation curves, make sure you've run `w
 wandb.enable=true
 ```
 
-A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explaination of some commonly used metrics in logs.
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
 
 ![](media/wandb.png)
 
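For reference, the dataset class mentioned in the README hunk above is typically used as in the following minimal sketch. The `LeRobotDataset` import path and the `lerobot/pusht` repo id are assumptions based on the repository's examples, not part of this diff, and may differ on the dataset-reshape branch being merged here:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Instantiating the class fetches (and caches) the dataset from the Hugging Face hub.
# "lerobot/pusht" is only an illustrative repo id.
dataset = LeRobotDataset("lerobot/pusht")

print(len(dataset))   # number of frames in the dataset
frame = dataset[0]    # a dict of tensors (camera images, state, action, ...)
print(frame.keys())
```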

lerobot/common/datasets/compute_stats.py (+92)
@@ -206,3 +206,95 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
             )
         )
     return stats
+
+
+# TODO(aliberts): refactor stats in save_episodes
+# import numpy as np
+# from lerobot.common.datasets.utils import load_image_as_numpy
+# def aggregate_stats_v2(stats_list: list) -> dict:
+#     """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+#     The final stats will have the union of all data keys from each of the stats dicts.
+
+#     For instance:
+#     - new_min = min(min_dataset_0, min_dataset_1, ...)
+#     - new_max = max(max_dataset_0, max_dataset_1, ...)
+#     - new_mean = (mean of all data, weighted by counts)
+#     - new_std = (std of all data)
+#     """
+#     data_keys = set(key for stats in stats_list for key in stats.keys())
+#     aggregated_stats = {key: {} for key in data_keys}
+
+#     for key in data_keys:
+#         # Collect stats for the current key from all datasets where it exists
+#         stats_with_key = [stats[key] for stats in stats_list if key in stats]
+
+#         # Aggregate 'min' and 'max' using np.minimum and np.maximum
+#         aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
+#         aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])
+
+#         # Extract means, variances (std^2), and counts
+#         means = np.array([s['mean'] for s in stats_with_key])
+#         variances = np.array([s['std']**2 for s in stats_with_key])
+#         counts = np.array([s['count'] for s in stats_with_key])
+
+#         # Ensure counts can broadcast with means/variances if they have additional dimensions
+#         counts = counts.reshape(-1, *[1]*(means.ndim - 1))
+
+#         # Compute total counts
+#         total_count = counts.sum(axis=0)
+
+#         # Compute the weighted mean
+#         weighted_means = means * counts
+#         total_mean = weighted_means.sum(axis=0) / total_count
+
+#         # Compute the variance using the parallel algorithm
+#         delta_means = means - total_mean
+#         weighted_variances = (variances + delta_means**2) * counts
+#         total_variance = weighted_variances.sum(axis=0) / total_count
+
+#         # Store the aggregated stats
+#         aggregated_stats[key]['mean'] = total_mean
+#         aggregated_stats[key]['std'] = np.sqrt(total_variance)
+#         aggregated_stats[key]['count'] = total_count
+
+#     return aggregated_stats
+
+
+# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
+#     stats = {}
+#     for key, data in episode_buffer.items():
+#         if features[key]["dtype"] in ["image", "video"]:
+#             stats[key] = compute_image_stats(data, sampling=image_sampling)
+#         else:
+#             axes_to_reduce = 0  # Compute stats over the first axis
+#             stats[key] = {
+#                 "min": np.min(data, axis=axes_to_reduce),
+#                 "max": np.max(data, axis=axes_to_reduce),
+#                 "mean": np.mean(data, axis=axes_to_reduce),
+#                 "std": np.std(data, axis=axes_to_reduce),
+#                 "count": episode_length,
+#             }
+#     return stats
+
+
+# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
+#     images = []
+#     samples = range(0, len(image_paths), sampling)
+#     for idx in samples:
+#         path = image_paths[idx]
+#         img = load_image_as_numpy(path, channel_first=True)
+#         images.append(img)
+
+#     images = np.stack(images)
+#     axes_to_reduce = (0, 2, 3)  # keep channel dim
+#     image_stats = {
+#         "min": np.min(images, axis=axes_to_reduce, keepdims=True),
+#         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
+#         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
+#         "std": np.std(images, axis=axes_to_reduce, keepdims=True)
+#     }
+#     for key in image_stats:  # squeeze batch dim
+#         image_stats[key] = np.squeeze(image_stats[key], axis=0)
+
+#     return image_stats
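The commented-out `aggregate_stats_v2` above combines per-dataset statistics using count-weighted means and the parallel (pooled) variance formula. As a sanity check of that math, here is a small self-contained sketch, independent of the lerobot codebase (all names below are illustrative): pooling per-chunk mean/std/count this way reproduces the statistics of the concatenated data.

```python
import numpy as np

rng = np.random.default_rng(0)
chunks = [rng.normal(size=(n, 3)) for n in (100, 250, 50)]  # e.g. per-episode feature arrays

# Per-chunk stats, analogous to what a per-episode stats pass would produce for a non-image key.
stats = [{"mean": c.mean(axis=0), "std": c.std(axis=0), "count": len(c)} for c in chunks]

counts = np.array([s["count"] for s in stats]).reshape(-1, 1)
means = np.array([s["mean"] for s in stats])
variances = np.array([s["std"] for s in stats]) ** 2

total_count = counts.sum()
total_mean = (means * counts).sum(axis=0) / total_count
# Parallel variance: within-chunk variance plus between-chunk spread, count-weighted.
total_var = ((variances + (means - total_mean) ** 2) * counts).sum(axis=0) / total_count

# Matches the stats of the concatenated data exactly (population std, ddof=0).
full = np.concatenate(chunks)
assert np.allclose(total_mean, full.mean(axis=0))
assert np.allclose(np.sqrt(total_var), full.std(axis=0))
```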
