Skip to content

Commit eda02fa

Browse files
committed
Skip test_visualize_local_dataset
1 parent 8546358 commit eda02fa

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

lerobot/scripts/visualize_dataset.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -100,24 +100,22 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
100100

101101

102102
def visualize_dataset(
103-
repo_id: str,
103+
dataset: LeRobotDataset,
104104
episode_index: int,
105105
batch_size: int = 32,
106106
num_workers: int = 0,
107107
mode: str = "local",
108108
web_port: int = 9090,
109109
ws_port: int = 9087,
110110
save: bool = False,
111-
root: Path | None = None,
112111
output_dir: Path | None = None,
113112
) -> Path | None:
114113
if save:
115114
assert (
116115
output_dir is not None
117116
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
118117

119-
logging.info("Loading dataset")
120-
dataset = LeRobotDataset(repo_id, root=root)
118+
repo_id = dataset.repo_id
121119

122120
logging.info("Loading dataloader")
123121
episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -268,7 +266,14 @@ def main():
268266
)
269267

270268
args = parser.parse_args()
271-
visualize_dataset(**vars(args))
269+
kwargs = vars(args)
270+
repo_id = kwargs.pop("repo_id")
271+
root = kwargs.pop("root")
272+
273+
logging.info("Loading dataset")
274+
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
275+
276+
visualize_dataset(dataset, **vars(args))
272277

273278

274279
if __name__ == "__main__":

tests/test_visualize_dataset.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,21 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
from pathlib import Path
17-
1816
import pytest
1917

2018
from lerobot.scripts.visualize_dataset import visualize_dataset
2119

2220

23-
@pytest.mark.parametrize(
24-
"repo_id",
25-
["lerobot/pusht"],
26-
)
27-
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
28-
def test_visualize_local_dataset(tmpdir, repo_id, root):
21+
@pytest.mark.skip("TODO: add dummy videos")
22+
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
23+
root = tmp_path / "dataset"
24+
output_dir = tmp_path / "outputs"
25+
dataset = lerobot_dataset_factory(root=root)
2926
rrd_path = visualize_dataset(
30-
repo_id,
27+
dataset,
3128
episode_index=0,
3229
batch_size=32,
3330
save=True,
34-
output_dir=tmpdir,
35-
root=root,
31+
output_dir=output_dir,
3632
)
3733
assert rrd_path.exists()

0 commit comments

Comments
 (0)