@@ -243,6 +243,10 @@ def write_video_info(self) -> None:
243
243
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
244
244
been encoded the same way. Also, this means it assumes the first episode exists.
245
245
"""
246
+ # TODO(rcadene): What should we do here?
247
+ if "videos" not in self .info :
248
+ self .info ["videos" ] = {}
249
+
246
250
for key in self .video_keys :
247
251
if key not in self .info ["videos" ]:
248
252
video_path = self .root / self .get_video_file_path (ep_index = 0 , vid_key = key )
@@ -260,6 +264,8 @@ def create(
260
264
robot_type : str | None = None ,
261
265
features : dict | None = None ,
262
266
use_videos : bool = True ,
267
+ # tags: list[str] | None = None,
268
+ # license_type: str | None = None,
263
269
) -> "LeRobotDatasetMetadata" :
264
270
"""Creates metadata for a LeRobotDataset."""
265
271
obj = cls .__new__ (cls )
@@ -283,8 +289,12 @@ def create(
283
289
else :
284
290
features = {** features , ** DEFAULT_FEATURES }
285
291
292
+ # TODO(rcadene): implement sanity check for features
293
+
286
294
obj .tasks , obj .stats , obj .episodes = {}, {}, []
287
295
obj .info = create_empty_dataset_info (CODEBASE_VERSION , fps , robot_type , features , use_videos )
296
+ # obj.tags = tags
297
+ # obj.license_type = license_type
288
298
if len (obj .video_keys ) > 0 and not use_videos :
289
299
raise ValueError ()
290
300
write_json (obj .info , obj .root / INFO_PATH )
@@ -646,7 +656,7 @@ def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
646
656
current_ep_idx = self .meta .total_episodes if episode_index is None else episode_index
647
657
return {
648
658
"size" : 0 ,
649
- ** {key : [] if key != "episode_index" else current_ep_idx for key in self .features },
659
+ ** {key : current_ep_idx if key == "episode_index" else [] for key in self .features },
650
660
}
651
661
652
662
def _get_image_file_path (self , episode_index : int , image_key : str , frame_index : int ) -> Path :
@@ -669,6 +679,9 @@ def add_frame(self, frame: dict) -> None:
669
679
temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
670
680
then needs to be called.
671
681
"""
682
+ # TODO(rcadene): Add sanity check for the input, check it's numpy or torch,
683
+ # check the dtype and shape matches, etc.
684
+
672
685
frame_index = self .episode_buffer ["size" ]
673
686
for key , ft in self .features .items ():
674
687
if key == "frame_index" :
@@ -705,6 +718,11 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
705
718
# TODO(aliberts): Add option to use existing episode_index
706
719
raise NotImplementedError ()
707
720
721
+ if episode_length == 0 :
722
+ raise ValueError (
723
+ "You must add one or several frames with `add_frame` before calling `add_episode`."
724
+ )
725
+
708
726
task_index = self .meta .get_task_index (task )
709
727
710
728
if not set (self .episode_buffer .keys ()) == set (self .features ):
@@ -719,11 +737,14 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
719
737
self .episode_buffer [key ] = np .full ((episode_length ,), episode_index )
720
738
elif key == "task_index" :
721
739
self .episode_buffer [key ] = np .full ((episode_length ,), task_index )
722
- elif ft ["dtype" ] in [ "image" , "video" ] :
740
+ elif ft ["dtype" ] == "image" :
723
741
continue
742
+ elif ft ["dtype" ] == "video" :
743
+ del self .episode_buffer [key ]
724
744
elif ft ["shape" ][0 ] == 1 :
725
745
self .episode_buffer [key ] = torch .tensor (self .episode_buffer [key ])
726
746
elif ft ["shape" ][0 ] > 1 :
747
+ # TODO(rcadene): why torch over here?
727
748
self .episode_buffer [key ] = torch .stack (self .episode_buffer [key ])
728
749
else :
729
750
raise ValueError ()
0 commit comments