34 changes: 22 additions & 12 deletions python/pyarrow/_dataset.pyx
@@ -744,6 +744,27 @@ cdef class Dataset:
result = self.dataset.NewScanWithContext(context)
return ScannerBuilder.wrap(GetResultValue(result))

def _scanner(self, columns=None, filter=None, use_threads=None,
MemoryPool memory_pool=None):
builder = self.new_scan(memory_pool)
if columns is not None:
builder.project(columns)
if filter is not None:
builder.filter(filter)
if use_threads is not None:
builder.use_threads(use_threads)
return builder.finish()

def scan(self, columns=None, filter=None, use_threads=None,
MemoryPool memory_pool=None):
scanner = self._scanner(columns, filter, use_threads, memory_pool)
return scanner.scan()

def to_table(self, columns=None, filter=None, use_threads=None,
MemoryPool memory_pool=None):
scanner = self._scanner(columns, filter, use_threads, memory_pool)
return scanner.to_table()

@property
def sources(self):
"""List of the data sources"""
@@ -775,7 +796,7 @@ cdef class ScanTask:

@staticmethod
cdef wrap(shared_ptr[CScanTask]& sp):
cdef SimpleScanTask self = SimpleScanTask.__new__(SimpleScanTask)
cdef ScanTask self = ScanTask.__new__(ScanTask)
self.init(sp)
return self

@@ -806,17 +827,6 @@ cdef class ScanTask:
yield pyarrow_wrap_batch(record_batch)


cdef class SimpleScanTask(ScanTask):
"""A trivial ScanTask that yields the RecordBatch of an array."""

cdef:
CSimpleScanTask* simple_task

cdef init(self, shared_ptr[CScanTask]& sp):
ScanTask.init(self, sp)
self.simple_task = <CSimpleScanTask*> sp.get()


cdef class ScannerBuilder:
"""Factory class to construct a Scanner.

124 changes: 124 additions & 0 deletions python/pyarrow/dataset.py
@@ -20,17 +20,21 @@
from __future__ import absolute_import

import sys
import functools

if sys.version_info < (3,):
raise ImportError("Python Dataset bindings require Python 3")

import pyarrow as pa
from pyarrow.fs import FileSelector, FileType, LocalFileSystem
from pyarrow._dataset import ( # noqa
AndExpression,
CastExpression,
CompareOperator,
ComparisonExpression,
Dataset,
DataSource,
DataSourceDiscovery,
DefaultPartitionScheme,
Expression,
FieldExpression,
@@ -53,3 +57,123 @@
SchemaPartitionScheme,
TreeDataSource,
)


def partitioning(field_names=None, flavor=None):
if flavor is None:
if field_names is None:
return None # no partitioning
elif isinstance(field_names, pa.Schema):
return SchemaPartitionScheme(field_names)
elif isinstance(field_names, list):
return SchemaPartitionScheme.discover(field_names)
else:
raise ValueError('Either pass a schema or a list of field names')
elif flavor == 'hive':
if isinstance(field_names, pa.Schema):
return HivePartitionScheme(field_names)
elif isinstance(field_names, list):
raise ValueError('Not yet supported')
# # it would be nice to push this down into the C++ implementation
# schema = HivePartitionScheme.discover()
# # limit the schema to have only the required fields
# schema = pa.schema([schema[name] for name in field_names])
# # create the partitioning
elif field_names is None:
return HivePartitionScheme.discover()
else:
raise ValueError('Not yet supported')
else:
raise ValueError("unsupported flavor '{}'".format(flavor))
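For orientation, a short sketch of how this partitioning() helper is meant to be called, mirroring the branches above; the field names and types are made up:

import pyarrow as pa
import pyarrow.dataset as ds

# directory-based partitioning, key types discovered from a list of field names
part = ds.partitioning(['year', 'month'])

# the same, but with the key types spelled out as a schema
part = ds.partitioning(pa.schema([('year', pa.int16()), ('month', pa.int8())]))

# hive-style key=value directories, keys and types discovered from the paths
part = ds.partitioning(flavor='hive')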


def _ensure_partitioning(obj):
if isinstance(obj, (PartitionScheme, PartitionSchemeDiscovery)):
return obj
elif isinstance(obj, str):
return partitioning(flavor=obj)
else:
return partitioning(obj)


def _ensure_format(obj):
if isinstance(obj, FileFormat):
return obj
elif obj == "parquet":
return ParquetFileFormat()
else:
raise ValueError("format '{0}' is not supported".format(format))


def _ensure_selector(fs, obj):
if isinstance(obj, str):
path = fs.get_target_stats([obj])[0]
if path.type == FileType.Directory:
# for directory, pass a selector
return FileSelector(obj, recursive=True)
else:
# is a single file path, pass it as a list
return [obj]
elif isinstance(obj, list):
assert all(isinstance(path, str) for path in obj)
return obj
else:
raise ValueError('Unsupported paths or selector')


def source(src, fs=None, partitioning=None, format=None, **options):
# src: path/paths/table
if isinstance(src, pa.Table):
raise NotImplementedError('InMemorySource is not yet supported')

if fs is None:
# TODO handle other file systems
fs = LocalFileSystem()

paths = _ensure_selector(fs, src)
format = _ensure_format(format)
partitioning = _ensure_partitioning(partitioning)

options = FileSystemDiscoveryOptions(**options)
if isinstance(partitioning, PartitionSchemeDiscovery):
options.partition_scheme_discovery = partitioning
elif isinstance(partitioning, PartitionScheme):
options.partition_scheme = partitioning

return FileSystemDataSourceDiscovery(fs, paths, format, options)
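A hedged sketch of calling source() as defined above; the directory is invented, and inspect()/finish() are the discovery methods relied on by dataset() below:

from pyarrow.fs import LocalFileSystem
import pyarrow.dataset as ds

# '/data/events' stands in for a directory of hive-partitioned parquet files
discovery = ds.source('/data/events', fs=LocalFileSystem(),
                      format='parquet', partitioning='hive')

schema = discovery.inspect()      # peek at the schema before committing
data_source = discovery.finish()  # materialize the actual DataSource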


def _unify_schemas(schemas):
# calculate the subset of fields available in all schemas
keys_in_order = schemas[0].names
keys = set(keys_in_order).intersection(*[s.names for s in schemas[1:]])
if not keys:
raise ValueError('No common fields found among the given schemas')

# create subschemas from each individual schema
schemas = [pa.schema(s.field(k) for k in keys_in_order if k in keys) for s in schemas]

# check that the subschemas' fields are equal except their additional
# key-value metadata
if any(schemas[0] != s for s in schemas):
raise ValueError('Schema fields are not equal across the given schemas')

return schemas[0]
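To make the intersection step concrete, a tiny invented example: only the fields present in every schema survive, in the order of the first schema, and the shared fields must then compare equal:

import pyarrow as pa

a = pa.schema([('date', pa.date32()), ('index', pa.int64()), ('color', pa.string())])
b = pa.schema([('index', pa.int64()), ('color', pa.string()), ('extra', pa.bool_())])

# _unify_schemas([a, b]) would keep the common fields in a's order:
# index: int64, color: string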


def dataset(sources, schema=None):
# DataSource has no schema, so we cannot check whether it is compatible
# with the other DataSource objects or with DataSourceDiscovery.Inspect()
# results, so we limit this exclusively to Discovery objects.
# TODO(kszucs): support list of filesystem uris
if isinstance(sources, DataSourceDiscovery):
return dataset([sources], schema=schema)
elif isinstance(sources, list):
assert all(isinstance(obj, DataSourceDiscovery) for obj in sources)
# if not explicitly given, derive a schema which all data sources can be
# coerced to
if schema is None:
schema = _unify_schemas([discovery.inspect() for discovery in sources])
# finalize the discovery objects to actual data sources
sources = [discovery.finish() for discovery in sources]
return Dataset(sources, schema=schema)
else:
raise ValueError('must pass a DataSourceDiscovery or a list of them')
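Tying it together, a hedged end-to-end sketch of the intended flow; the layout is hypothetical and mirrors the test added below:

import pyarrow.dataset as ds

plain = ds.source('/warehouse/plain', format='parquet')
hive = ds.source('/warehouse/events', format='parquet', partitioning='hive')

dataset = ds.dataset([plain, hive])  # unify schemas, finish the discoveries
table = dataset.to_table()           # read everything into one pyarrow.Table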
95 changes: 95 additions & 0 deletions python/pyarrow/tests/test_dataset.py
@@ -17,9 +17,14 @@

import pytest

import numpy as np
import pandas as pd

import pyarrow as pa
import pyarrow.fs as fs

from pyarrow.fs import FileSelector, FileType, LocalFileSystem

try:
import pyarrow.dataset as ds
except ImportError:
@@ -62,6 +67,96 @@ def mockfs():
return mockfs


def _generate_data(n):
import datetime
import itertools

day = datetime.datetime(2000, 1, 1)
interval = datetime.timedelta(days=1)
colors = itertools.cycle(['green', 'blue', 'yellow', 'red', 'orange'])

data = []
for i in range(n):
data.append((day, i, float(i), next(colors)))
day += interval

return pd.DataFrame(data, columns=['date', 'index', 'value', 'color'])


def _table_from_pandas(df):
schema = pa.schema([
pa.field('date', pa.date32()),
pa.field('index', pa.int64()),
pa.field('value', pa.float64()),
pa.field('color', pa.string()),
])
table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
return table.replace_schema_metadata()


@pytest.fixture(scope='module')
@pytest.mark.parquet
def multisourcefs():
import pyarrow.parquet as pq

df = _generate_data(3000)
mockfs = fs._MockFileSystem()

# simply split the dataframe into three chunks to construct a data source
# from each chunk into its own directory
df_a, df_b, df_c = np.array_split(df, 3)

# create a directory containing a flat sequence of parquet files without
# any partitioning involved
mockfs.create_dir('plain')
for i, chunk in enumerate(np.array_split(df_a, 10)):
path = 'plain/chunk-{}.parquet'.format(i)
with mockfs.open_output_stream(path) as out:
pq.write_table(_table_from_pandas(chunk), out)

# create one with schema partitioning by week and color
mockfs.create_dir('schema')
for part, chunk in df.groupby([df.date.dt.week, df.color]):
folder = 'schema/{}/{}'.format(*part)
path = '{}/chunk.parquet'.format(folder)
mockfs.create_dir(folder)
with mockfs.open_output_stream(path) as out:
pq.write_table(_table_from_pandas(chunk), out)

# create one with hive partitioning by year and month
mockfs.create_dir('hive')
for part, chunk in df.groupby([df.date.dt.year, df.date.dt.month]):
folder = 'hive/year={}/month={}'.format(*part)
path = '{}/chunk.parquet'.format(folder)
mockfs.create_dir(folder)
with mockfs.open_output_stream(path) as out:
pq.write_table(_table_from_pandas(chunk), out)

return mockfs


def test_multiple_sources(multisourcefs):
src1 = ds.source('/plain', fs=multisourcefs, format='parquet')
src2 = ds.source('/schema', fs=multisourcefs, format='parquet',
partitioning=['week', 'color'])
src3 = ds.source('/hive', fs=multisourcefs, format='parquet',
partitioning='hive')

assembled = ds.dataset([src1, src2, src3])
assert isinstance(assembled, ds.Dataset)

expected_schema = pa.schema([
pa.field('date', pa.date32()),
pa.field('index', pa.int64()),
pa.field('value', pa.float64()),
pa.field('color', pa.string()),
])
assert assembled.schema.equals(expected_schema)

table = assembled.to_table()
print(table.to_pandas())


@pytest.fixture
def dataset(mockfs):
format = ds.ParquetFileFormat()