File tree 2 files changed +17
-16
lines changed
2 files changed +17
-16
lines changed Original file line number Diff line number Diff line change @@ -100,24 +100,22 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
100
100
101
101
102
102
def visualize_dataset (
103
- repo_id : str ,
103
+ dataset : LeRobotDataset ,
104
104
episode_index : int ,
105
105
batch_size : int = 32 ,
106
106
num_workers : int = 0 ,
107
107
mode : str = "local" ,
108
108
web_port : int = 9090 ,
109
109
ws_port : int = 9087 ,
110
110
save : bool = False ,
111
- root : Path | None = None ,
112
111
output_dir : Path | None = None ,
113
112
) -> Path | None :
114
113
if save :
115
114
assert (
116
115
output_dir is not None
117
116
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
118
117
119
- logging .info ("Loading dataset" )
120
- dataset = LeRobotDataset (repo_id , root = root )
118
+ repo_id = dataset .repo_id
121
119
122
120
logging .info ("Loading dataloader" )
123
121
episode_sampler = EpisodeSampler (dataset , episode_index )
@@ -268,7 +266,14 @@ def main():
268
266
)
269
267
270
268
args = parser .parse_args ()
271
- visualize_dataset (** vars (args ))
269
+ kwargs = vars (args )
270
+ repo_id = kwargs .pop ("repo_id" )
271
+ root = kwargs .pop ("root" )
272
+
273
+ logging .info ("Loading dataset" )
274
+ dataset = LeRobotDataset (repo_id , root = root , local_files_only = True )
275
+
276
+ visualize_dataset (dataset , ** vars (args ))
272
277
273
278
274
279
if __name__ == "__main__" :
Original file line number Diff line number Diff line change 13
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
- from pathlib import Path
17
-
18
16
import pytest
19
17
20
18
from lerobot .scripts .visualize_dataset import visualize_dataset
21
19
22
20
23
- @pytest .mark .parametrize (
24
- "repo_id" ,
25
- ["lerobot/pusht" ],
26
- )
27
- @pytest .mark .parametrize ("root" , [Path (__file__ ).parent / "data" ])
28
- def test_visualize_local_dataset (tmpdir , repo_id , root ):
21
+ @pytest .mark .skip ("TODO: add dummy videos" )
22
+ def test_visualize_local_dataset (tmp_path , lerobot_dataset_factory ):
23
+ root = tmp_path / "dataset"
24
+ output_dir = tmp_path / "outputs"
25
+ dataset = lerobot_dataset_factory (root = root )
29
26
rrd_path = visualize_dataset (
30
- repo_id ,
27
+ dataset ,
31
28
episode_index = 0 ,
32
29
batch_size = 32 ,
33
30
save = True ,
34
- output_dir = tmpdir ,
35
- root = root ,
31
+ output_dir = output_dir ,
36
32
)
37
33
assert rrd_path .exists ()
You can’t perform that action at this time.
0 commit comments