diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 8956587d78a..cb67d62d13f 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -41,7 +41,6 @@
 
 from .blob import BlobFile
 from .dependencies import (
-    _check_for_hugging_face,
     _check_for_numpy,
     torch,
 )
@@ -5259,15 +5258,6 @@ def write_dataset(
     else:
         data_storage_version = "stable"
 
-    if _check_for_hugging_face(data_obj):
-        # Huggingface datasets
-        from .dependencies import datasets
-
-        if isinstance(data_obj, datasets.Dataset):
-            if schema is None:
-                schema = data_obj.features.arrow_schema
-            data_obj = data_obj.data.to_batches()
-
     reader = _coerce_reader(data_obj, schema)
     _validate_schema(reader.schema)
     # TODO add support for passing in LanceDataset and LanceScanner here
diff --git a/python/python/lance/types.py b/python/python/lance/types.py
index deafd59af51..9fcefec9a83 100644
--- a/python/python/lance/types.py
+++ b/python/python/lance/types.py
@@ -9,7 +9,7 @@
 from pyarrow import RecordBatch
 
 from . import dataset
-from .dependencies import _check_for_pandas
+from .dependencies import _check_for_hugging_face, _check_for_pandas
 from .dependencies import pandas as pd
 
 if TYPE_CHECKING:
@@ -74,6 +74,37 @@ def _coerce_reader(
         and data_obj.__class__.__name__ == "DataFrame"
     ):
         return data_obj.to_arrow().to_reader()
+    elif _check_for_hugging_face(data_obj):
+        from .dependencies import datasets as hf_datasets
+
+        if isinstance(data_obj, hf_datasets.Dataset):
+            if schema is None:
+                schema = data_obj.features.arrow_schema
+            return data_obj.data.to_reader()
+        elif isinstance(data_obj, hf_datasets.DatasetDict):
+            raise ValueError(
+                "DatasetDict is not yet supported. For now, please "
+                "iterate through the DatasetDict and pass single "
+                "Dataset instances (e.g., from dataset_dict.values()) to "
+                "`write_dataset`."
+            )
+        elif isinstance(data_obj, hf_datasets.IterableDataset):
+            if schema is None:
+                schema = data_obj.features.arrow_schema
+
+            def batch_iter():
+                # Try to provide a reasonable batch size. If the user needs
+                # to override this, they can do the conversion to a reader
+                # themselves.
+                for dict_batch in data_obj.iter(batch_size=1000):
+                    yield pa.RecordBatch.from_pydict(dict_batch, schema=schema)
+
+            return pa.RecordBatchReader.from_batches(schema, batch_iter())
+        else:
+            raise TypeError(
+                f"Unknown HuggingFace dataset type: {type(data_obj)}. "
+                "Please provide a single Dataset or IterableDataset."
+            )
     elif isinstance(data_obj, dict):
         batch = pa.RecordBatch.from_pydict(data_obj, schema=schema)
         return pa.RecordBatchReader.from_batches(batch.schema, [batch])
diff --git a/python/python/tests/test_huggingface.py b/python/python/tests/test_huggingface.py
index e4f460d7116..24dca5c8174 100644
--- a/python/python/tests/test_huggingface.py
+++ b/python/python/tests/test_huggingface.py
@@ -5,6 +5,7 @@
 
 import lance
 import numpy as np
+import pyarrow as pa
 import pytest
 
 datasets = pytest.importorskip("datasets")
@@ -45,3 +46,31 @@ def test_image_hf_dataset(tmp_path: Path):
         (isinstance(img, pil.Image.Image) and np.all(np.array(img) == 0))
         for img in batch
     )
+
+
+def test_iterable_dataset(tmp_path: Path):
+    # IterableDataset built from a generator that yields one row (dict) at a time
+
+    def gen():
+        yield {"text": "Good", "label": 0}
+        yield {"text": "Bad", "label": 1}
+
+    arrow_schema = pa.schema([("text", pa.string()), ("label", pa.int64())])
+    features = datasets.Features.from_arrow_schema(arrow_schema)
+
+    iter_ds = datasets.IterableDataset.from_generator(gen, features=features)
+    # write_dataset converts the stream using a default internal batch size
+    ds1 = lance.write_dataset(iter_ds, tmp_path / "ds1.lance")
+    assert ds1.count_rows() == 2
+    assert ds1.schema == iter_ds.features.arrow_schema
+
+    # to control the streaming batch size manually, iterate and append
+    ds2 = lance.write_dataset(
+        pa.Table.from_arrays([[], []], schema=arrow_schema), tmp_path / "ds2.lance"
+    )
+    for batch in iter_ds.iter(batch_size=1):
+        # each batch is a dict of column lists, which write_dataset also accepts
+        ds2 = lance.write_dataset(batch, tmp_path / "ds2.lance", mode="append")
+
+    assert len(ds1) == len(ds2)
+    assert ds1.schema == ds2.schema
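
The inline comment in `_coerce_reader` notes that users who need a different streaming batch size can build the reader themselves. A minimal sketch of that manual conversion, assuming an illustrative generator and `batch_size=256` (neither is part of this patch):

```python
import datasets
import pyarrow as pa

import lance


def gen():
    # hypothetical rows; stands in for any streaming source
    yield {"text": "Good", "label": 0}
    yield {"text": "Bad", "label": 1}


arrow_schema = pa.schema([("text", pa.string()), ("label", pa.int64())])
features = datasets.Features.from_arrow_schema(arrow_schema)
iter_ds = datasets.IterableDataset.from_generator(gen, features=features)


def batches():
    # IterableDataset.iter(batch_size=...) yields dicts of column lists
    for dict_batch in iter_ds.iter(batch_size=256):
        yield pa.RecordBatch.from_pydict(dict_batch, schema=arrow_schema)


# Wrapping the batches in a RecordBatchReader lets write_dataset stream them
# without materializing the whole dataset in memory.
reader = pa.RecordBatchReader.from_batches(arrow_schema, batches())
ds = lance.write_dataset(reader, "manual_batches.lance")
assert ds.count_rows() == 2
```

This is the same conversion the new `IterableDataset` branch performs internally with a hardcoded `batch_size=1000`; doing it by hand only changes where the batch size is chosen.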