Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from .blob import BlobFile
from .dependencies import (
_check_for_hugging_face,
_check_for_numpy,
torch,
)
Expand Down Expand Up @@ -5259,15 +5258,6 @@ def write_dataset(
else:
data_storage_version = "stable"

if _check_for_hugging_face(data_obj):
# Huggingface datasets
from .dependencies import datasets

if isinstance(data_obj, datasets.Dataset):
if schema is None:
schema = data_obj.features.arrow_schema
data_obj = data_obj.data.to_batches()

reader = _coerce_reader(data_obj, schema)
_validate_schema(reader.schema)
# TODO add support for passing in LanceDataset and LanceScanner here
Expand Down
33 changes: 32 additions & 1 deletion python/python/lance/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyarrow import RecordBatch

from . import dataset
from .dependencies import _check_for_pandas
from .dependencies import _check_for_hugging_face, _check_for_pandas
from .dependencies import pandas as pd

if TYPE_CHECKING:
Expand Down Expand Up @@ -74,6 +74,37 @@ def _coerce_reader(
and data_obj.__class__.__name__ == "DataFrame"
):
return data_obj.to_arrow().to_reader()
elif _check_for_hugging_face(data_obj):
from .dependencies import datasets as hf_datasets

if isinstance(data_obj, hf_datasets.Dataset):
if schema is None:
schema = data_obj.features.arrow_schema
return data_obj.data.to_reader()
elif isinstance(data_obj, hf_datasets.DatasetDict):
raise ValueError(
"DatasetDict is not yet supported. For now please "
"iterate through the DatasetDict and pass in single "
"Dataset instances (e.g., from dataset_dict.data) to "
"`write_dataset`. "
)
elif isinstance(data_obj, hf_datasets.IterableDataset):
if schema is None:
schema = data_obj.features.arrow_schema

def batch_iter():
# Try to provide a reasonable batch size. If the user needs to
# override this, they can do the conversion to a reader themselves.
for dict_batch in data_obj.iter(batch_size=1000):
yield pa.RecordBatch.from_pydict(dict_batch, schema=schema)

return pa.RecordBatchReader.from_batches(schema, batch_iter())
else:
raise TypeError(
f"Unknown HuggingFace dataset type: {type(data_obj)}. "
"Please provide a single Dataset or DatasetDict."
)

elif isinstance(data_obj, dict):
batch = pa.RecordBatch.from_pydict(data_obj, schema=schema)
return pa.RecordBatchReader.from_batches(batch.schema, [batch])
Expand Down
29 changes: 29 additions & 0 deletions python/python/tests/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import lance
import numpy as np
import pyarrow as pa
import pytest

datasets = pytest.importorskip("datasets")
Expand Down Expand Up @@ -45,3 +46,31 @@ def test_image_hf_dataset(tmp_path: Path):
(isinstance(img, pil.Image.Image) and np.all(np.array(img) == 0))
for img in batch
)


def test_iterable_dataset(tmp_path: Path):
    """Write a HuggingFace IterableDataset to Lance, both directly and batch-by-batch."""
    # An IterableDataset streams dicts of column values instead of
    # materializing an Arrow table up front.

    def row_generator():
        yield {"text": "Good", "label": 0}
        yield {"text": "Bad", "label": 1}

    arrow_schema = pa.schema([("text", pa.string()), ("label", pa.int64())])
    features = datasets.Features.from_arrow_schema(arrow_schema)
    iter_ds = datasets.IterableDataset.from_generator(row_generator, features=features)

    # Direct write: streaming batch size is controlled by max_rows_per_group.
    ds1 = lance.write_dataset(iter_ds, tmp_path / "ds1.lance")
    assert ds1.count_rows() == 2
    assert ds1.schema == iter_ds.features.arrow_schema

    # Manual control over the streaming batch size: seed an empty dataset,
    # then append one batch at a time.
    empty_table = pa.Table.from_arrays([[], []], schema=arrow_schema)
    ds2 = lance.write_dataset(empty_table, tmp_path / "ds2.lance")
    for dict_batch in iter_ds.iter(batch_size=1):
        # shouldn't fail
        ds2 = lance.write_dataset(dict_batch, tmp_path / "ds2.lance", mode="append")

    assert len(ds1) == len(ds2)
    assert ds1.schema == ds2.schema
Loading