Pass dataset_options when SafeLanceDataset reopens its dataset in a worker.

The worker-initialization branch of __getitems__ called lance.dataset(self.uri)
with no further arguments, so the options stored on the instance were not
applied to the reopened dataset. This patch forwards **self.dataset_options,
removes the function-local `import os` and the per-worker print(), and adds a
regression test that pins a dataset to its first version.

NOTE(review): `**self.dataset_options` assumes the attribute is always a
mapping (never None) -- confirm against SafeLanceDataset.__init__.
NOTE(review): the whitespace of this patch was reconstructed from a collapsed
copy; run `git apply --check` before applying.

diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py
index d5adcbbfe19..dc09cde3dc4 100644
--- a/python/python/lance/torch/data.py
+++ b/python/python/lance/torch/data.py
@@ -443,10 +443,7 @@ def __getitems__(self, indices):
         """
         if self._ds is None:
             # Worker-process initialization
-            import os
-
-            self._ds = lance.dataset(self.uri)
-            print(f"Worker {os.getpid()} initialized dataset")
+            self._ds = lance.dataset(self.uri, **self.dataset_options)
 
         # Leverage native batch reading
         batch = self._ds.take(indices)
diff --git a/python/python/tests/torch_tests/test_data.py b/python/python/tests/torch_tests/test_data.py
index 890a536cc9e..38b9439802f 100644
--- a/python/python/tests/torch_tests/test_data.py
+++ b/python/python/tests/torch_tests/test_data.py
@@ -12,7 +12,7 @@ from lance.sampler import ShardedBatchSampler, ShardedFragmentSampler
 
 torch = pytest.importorskip("torch")
 
-from lance.torch.data import LanceDataset  # noqa: E402
+from lance.torch.data import LanceDataset, SafeLanceDataset  # noqa: E402
 
 
 def test_iter_over_dataset_fixed_shape_tensor(tmp_path):
@@ -324,3 +324,32 @@ def to_tensor_fn(batch, *args, **kwargs):
     assert first["int"].shape == (4,)
     assert first["val"].dtype == torch.uint8
     assert first["val"].shape == (4, 100)
+
+
+def test_safe_lance_dataset_worker_uses_dataset_options(tmp_path: Path):
+    """Worker processes must reopen the dataset with dataset_options.
+
+    Regression test for: worker init called lance.dataset(uri) without
+    dataset_options, silently dropping version, storage_options, etc.
+    """
+    tbl_v1 = pa.table({"id": pa.array([1, 2, 3], pa.int64())})
+    ds = lance.write_dataset(tbl_v1, tmp_path / "data.lance")
+    version_1 = ds.version
+
+    # Overwrite with different ids to create a second, distinguishable version.
+    tbl_v2 = pa.table({"id": pa.array([10, 20, 30], pa.int64())})
+    lance.write_dataset(tbl_v2, tmp_path / "data.lance", mode="overwrite")
+
+    # Pin to version 1 via dataset_options.
+    safe_ds = SafeLanceDataset(
+        str(tmp_path / "data.lance"),
+        dataset_options={"version": version_1},
+    )
+
+    # Simulate a fresh worker: with _ds unset, __getitems__ must reopen the dataset.
+    safe_ds._ds = None
+    rows = safe_ds.__getitems__([0, 1, 2])
+
+    assert [r["id"] for r in rows] == [1, 2, 3], (
+        "Worker reopened dataset without dataset_options (got version 2 data)"
+    )