diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 3c6396b58ff..fe4a0364a45 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -744,6 +744,27 @@ cdef class Dataset:
         result = self.dataset.NewScanWithContext(context)
         return ScannerBuilder.wrap(GetResultValue(result))
 
+    def _scanner(self, columns=None, filter=None, use_threads=None,
+                 MemoryPool memory_pool=None):
+        builder = self.new_scan(memory_pool)
+        if columns is not None:
+            builder.project(columns)
+        if filter is not None:
+            builder.filter(filter)
+        if use_threads is not None:
+            builder.use_threads(use_threads)
+        return builder.finish()
+
+    def scan(self, columns=None, filter=None, use_threads=None,
+             MemoryPool memory_pool=None):
+        scanner = self._scanner(columns, filter, use_threads, memory_pool)
+        return scanner.scan()
+
+    def to_table(self, columns=None, filter=None, use_threads=None,
+                 MemoryPool memory_pool=None):
+        scanner = self._scanner(columns, filter, use_threads, memory_pool)
+        return scanner.to_table()
+
     @property
     def sources(self):
         """List of the data sources"""
@@ -775,7 +796,7 @@ cdef class ScanTask:
 
     @staticmethod
     cdef wrap(shared_ptr[CScanTask]& sp):
-        cdef SimpleScanTask self = SimpleScanTask.__new__(SimpleScanTask)
+        cdef ScanTask self = ScanTask.__new__(ScanTask)
         self.init(sp)
         return self
 
@@ -806,17 +827,6 @@ cdef class ScanTask:
             yield pyarrow_wrap_batch(record_batch)
 
 
-cdef class SimpleScanTask(ScanTask):
-    """A trivial ScanTask that yields the RecordBatch of an array."""
-
-    cdef:
-        CSimpleScanTask* simple_task
-
-    cdef init(self, shared_ptr[CScanTask]& sp):
-        ScanTask.init(self, sp)
-        self.simple_task = sp.get()
-
-
 cdef class ScannerBuilder:
     """Factory class to construct a Scanner.
 
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index c84b1782a20..8a6f15cae9a 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -20,10 +20,13 @@
 from __future__ import absolute_import
 
 import sys
+import functools
 
 if sys.version_info < (3,):
     raise ImportError("Python Dataset bindings require Python 3")
 
+import pyarrow as pa
+from pyarrow.fs import FileSelector, FileType, LocalFileSystem
 from pyarrow._dataset import (  # noqa
     AndExpression,
     CastExpression,
@@ -31,6 +34,7 @@
     ComparisonExpression,
     Dataset,
     DataSource,
+    DataSourceDiscovery,
     DefaultPartitionScheme,
     Expression,
     FieldExpression,
@@ -53,3 +57,123 @@
     SchemaPartitionScheme,
     TreeDataSource,
 )
+
+
+def partitioning(field_names=None, flavor=None):
+    if flavor is None:
+        if field_names is None:
+            return None  # no partitioning
+        elif isinstance(field_names, pa.Schema):
+            return SchemaPartitionScheme(field_names)
+        elif isinstance(field_names, list):
+            return SchemaPartitionScheme.discover(field_names)
+        else:
+            raise ValueError('Either pass a schema or a list of field names')
+    elif flavor == 'hive':
+        if isinstance(field_names, pa.Schema):
+            return HivePartitionScheme(field_names)
+        elif isinstance(field_names, list):
+            raise ValueError('Not yet supported')
+            # # it would be nice to push this down to the C++ implementation
+            # schema = HivePartitionScheme.discover()
+            # # limit the schema to have only the required fields
+            # schema = pa.schema([schema[name] for name in field_names])
+            # # create the partitioning
+        elif field_names is None:
+            return HivePartitionScheme.discover()
+        else:
+            raise ValueError('Not yet supported')
+    else:
+        return None
+
+
+def _ensure_partitioning(obj):
+    if isinstance(obj, (PartitionScheme, PartitionSchemeDiscovery)):
+        return obj
+    elif isinstance(obj, str):
+        return partitioning(flavor=obj)
+    else:
+        return partitioning(obj)
+
+
+def _ensure_format(obj):
+    if isinstance(obj, FileFormat):
+        return obj
+    elif obj == "parquet":
+        return ParquetFileFormat()
+    else:
+        raise ValueError("format '{0}' is not supported".format(obj))
+
+
+def _ensure_selector(fs, obj):
+    if isinstance(obj, str):
+        path = fs.get_target_stats([obj])[0]
+        if path.type == FileType.Directory:
+            # for a directory, pass a recursive selector
+            return FileSelector(obj, recursive=True)
+        else:
+            # a single file path, pass it as a one-element list
+            return [obj]
+    elif isinstance(obj, list):
+        assert all(isinstance(path, str) for path in obj)
+        return obj
+    else:
+        raise ValueError('Unsupported paths or selector')
+
+
+def source(src, fs=None, partitioning=None, format=None, **options):
+    # src: path/paths/table
+    if isinstance(src, pa.Table):
+        raise NotImplementedError('InMemorySource is not yet supported')
+
+    if fs is None:
+        # TODO handle other file systems
+        fs = LocalFileSystem()
+
+    paths = _ensure_selector(fs, src)
+    format = _ensure_format(format)
+    partitioning = _ensure_partitioning(partitioning)
+
+    options = FileSystemDiscoveryOptions(**options)
+    if isinstance(partitioning, PartitionSchemeDiscovery):
+        options.partition_scheme_discovery = partitioning
+    elif isinstance(partitioning, PartitionScheme):
+        options.partition_scheme = partitioning
+
+    return FileSystemDataSourceDiscovery(fs, paths, format, options)
+
+
+def _unify_schemas(schemas):
+    # calculate the subset of fields available in all schemas (in order)
+    keys = [k for k in schemas[0].names if all(k in s.names for s in schemas)]
+    if not keys:
+        raise ValueError('No common fields found among the given schemas')
+
+    # create subschemas from each individual schema
+    schemas = [pa.schema(s.field(k) for k in keys) for s in schemas]
+
+    # check that the subschemas' fields are equal except for their
+    # additional key-value metadata
+    if any(schemas[0] != s for s in schemas):
+        raise ValueError('Schema fields are not equal across sources')
+
+    return schemas[0]
+
+
+def dataset(sources, schema=None):
+    # DataSource has no schema, so we cannot check whether it is compatible
+    # with the rest of the DataSource objects or DataSourceDiscovery.Inspect()
+    # results, so limit the input exclusively to Discovery objects.
+    # TODO(kszucs): support list of filesystem uris
+    if isinstance(sources, DataSourceDiscovery):
+        return dataset([sources], schema=schema)
+    elif isinstance(sources, list):
+        assert all(isinstance(obj, DataSourceDiscovery) for obj in sources)
+        # try to create a schema which all data sources can coerce to
+        schema = _unify_schemas([discovery.inspect() for discovery in sources])
+        # finalize the discovery objects into actual data sources
+        sources = [discovery.finish() for discovery in sources]
+        return Dataset(sources, schema=schema)
+    else:
+        raise ValueError('Expected a DataSourceDiscovery or a list of them')
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 17631b94d47..6fc5fda4d89 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -17,9 +17,14 @@
 
 import pytest
 
+import numpy as np
+import pandas as pd
+
 import pyarrow as pa
 import pyarrow.fs as fs
 
+from pyarrow.fs import FileSelector, FileType, LocalFileSystem
+
 try:
     import pyarrow.dataset as ds
 except ImportError:
@@ -62,6 +67,96 @@ def mockfs():
     return mockfs
 
 
+def _generate_data(n):
+    import datetime
+    import itertools
+
+    day = datetime.datetime(2000, 1, 1)
+    interval = datetime.timedelta(days=1)
+    colors = itertools.cycle(['green', 'blue', 'yellow', 'red', 'orange'])
+
+    data = []
+    for i in range(n):
+        data.append((day, i, float(i), next(colors)))
+        day += interval
+
+    return pd.DataFrame(data, columns=['date', 'index', 'value', 'color'])
+
+
+def _table_from_pandas(df):
+    schema = pa.schema([
+        pa.field('date', pa.date32()),
+        pa.field('index', pa.int64()),
+        pa.field('value', pa.float64()),
+        pa.field('color', pa.string()),
+    ])
+    table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
+    return table.replace_schema_metadata()
+
+
+@pytest.fixture(scope='module')
+@pytest.mark.parquet
+def multisourcefs():
+    import pyarrow.parquet as pq
+
+    df = _generate_data(3000)
+    mockfs = fs._MockFileSystem()
+
+    # simply split the dataframe into three chunks, each backing a data
+    # source written into its own directory
+    df_a, df_b, df_c = np.array_split(df, 3)
+
+    # create a directory containing a flat sequence of parquet files without
+    # any partitioning involved
+    mockfs.create_dir('plain')
+    for i, chunk in enumerate(np.array_split(df_a, 10)):
+        path = 'plain/chunk-{}.parquet'.format(i)
+        with mockfs.open_output_stream(path) as out:
+            pq.write_table(_table_from_pandas(chunk), out)
+
+    # create one with schema partitioning by week and color
+    mockfs.create_dir('schema')
+    for part, chunk in df.groupby([df.date.dt.week, df.color]):
+        folder = 'schema/{}/{}'.format(*part)
+        path = '{}/chunk.parquet'.format(folder)
+        mockfs.create_dir(folder)
+        with mockfs.open_output_stream(path) as out:
+            pq.write_table(_table_from_pandas(chunk), out)
+
+    # create one with hive partitioning by year and month
+    mockfs.create_dir('hive')
+    for part, chunk in df.groupby([df.date.dt.year, df.date.dt.month]):
+        folder = 'hive/year={}/month={}'.format(*part)
+        path = '{}/chunk.parquet'.format(folder)
+        mockfs.create_dir(folder)
+        with mockfs.open_output_stream(path) as out:
+            pq.write_table(_table_from_pandas(chunk), out)
+
+    return mockfs
+
+
+def test_multiple_sources(multisourcefs):
+    src1 = ds.source('/plain', fs=multisourcefs, format='parquet')
+    src2 = ds.source('/schema', fs=multisourcefs, format='parquet',
+                     partitioning=['week', 'color'])
+    src3 = ds.source('/hive', fs=multisourcefs, format='parquet',
+                     partitioning='hive')
+
+    assembled = ds.dataset([src1, src2, src3])
+    assert isinstance(assembled, ds.Dataset)
+
+    expected_schema = pa.schema([
+        pa.field('date', pa.date32()),
+        pa.field('index', pa.int64()),
+        pa.field('value', pa.float64()),
+        pa.field('color', pa.string()),
+    ])
+    assert assembled.schema.equals(expected_schema)
+
+    table = assembled.to_table()
+    print(table.to_pandas())
+
+
 @pytest.fixture
 def dataset(mockfs):
     format = ds.ParquetFileFormat()
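For reference, the high-level helpers introduced above compose roughly as follows. This is a minimal usage sketch, not part of the patch: it assumes a local directory named data containing Parquet files laid out with hive-style year=/month= subdirectories and a value column, and it only exercises the ds.source(), ds.dataset() and Dataset.to_table() entry points added in this diff.

    import pyarrow.dataset as ds

    # discover a data source from a directory of parquet files; the 'hive'
    # flavor infers partition columns from key=value path segments
    # (the 'data' directory and its layout are hypothetical)
    discovery = ds.source('data', format='parquet', partitioning='hive')

    # inspect the discovered schema and finish the discovery into a Dataset
    dataset = ds.dataset(discovery)

    # materialize a projected subset of columns as a pyarrow.Table
    # (assumes the files contain a 'value' column)
    table = dataset.to_table(columns=['value'])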