113 changes: 50 additions & 63 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -24,7 +24,6 @@
RetryingPyFileSystem,
_check_pyarrow_version,
_is_local_scheme,
call_with_retry,
iterate_with_retry,
)
from ray.data.block import Block, BlockAccessor
@@ -33,10 +32,13 @@
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import FileShuffleConfig
from ray.data.datasource.file_meta_provider import (
DefaultFileMetadataProvider,
_handle_read_os_error,
_list_files,
)
from ray.data.datasource.parquet_meta_provider import (
ParquetFileMetadata,
ParquetMetadataProvider,
)
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
from ray.data.datasource.partitioning import (
PartitionDataType,
Partitioning,
@@ -101,12 +103,19 @@ class _SampleInfo:
estimated_bytes_per_row: Optional[int]


# TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
# raw pyarrow file fragment causes S3 network calls.
class SerializedFragment:
def __init__(self, frag: "ParquetFileFragment"):
self._data = cloudpickle.dumps(
(frag.format, frag.path, frag.filesystem, frag.partition_expression)
class _NoIOSerializableFragmentWrapper:
"""This is a workaround to avoid utilizing `ParquetFileFragment` original
serialization protocol that actually does network RPCs during serialization
(to fetch metadata)"""

def __init__(self, f: "ParquetFileFragment"):
self._fragment = f

def __reduce__(self):
return self._fragment.format.make_fragment, (
self._fragment.path,
self._fragment.filesystem,
self._fragment.partition_expression,
)

def deserialize(self) -> "ParquetFileFragment":
@@ -122,7 +131,7 @@ def deserialize(self) -> "ParquetFileFragment":

# Visible for test mocking.
def _deserialize_fragments(
serialized_fragments: List[SerializedFragment],
serialized_fragments: List[_NoIOSerializableFragmentWrapper],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
return [p.deserialize() for p in serialized_fragments]
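A minimal round-trip sketch (not part of this diff) of why the wrapper is safe to pickle: serialization captures only the fragment's format, path, filesystem, and partition expression, and unpickling rebuilds a real `ParquetFileFragment` via `make_fragment`, so downstream code such as `read_fragments` below can accept fragments directly without an explicit deserialize step. The S3 location is hypothetical.

import pyarrow.dataset as ds
from ray import cloudpickle

dataset = ds.dataset("s3://bucket/table/", format="parquet")  # hypothetical location
fragment = next(iter(dataset.get_fragments()))

wrapped = _NoIOSerializableFragmentWrapper(fragment)
# Unpickling calls `fragment.format.make_fragment(path, filesystem, partition_expression)`,
# which builds the fragment lazily; no Parquet footer is fetched until the fragment is read.
restored = cloudpickle.loads(cloudpickle.dumps(wrapped))
assert restored.path == fragment.path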

@@ -204,28 +213,24 @@ def __init__(
retryable_errors=DataContext.get_current().retried_io_errors,
)

# HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
# files. To avoid this, we expand the input paths with the default metadata
# provider and then apply the partition filter or file extensions.
if partition_filter is not None or file_extensions is not None:
default_meta_provider = DefaultFileMetadataProvider()
expanded_paths, _ = map(
list, zip(*default_meta_provider.expand_paths(paths, filesystem))
)

paths = list(expanded_paths)
if partition_filter is not None:
paths = partition_filter(paths)
if file_extensions is not None:
paths = [
path for path in paths if _has_file_extension(path, file_extensions)
]
listed_files = _list_files(
paths,
filesystem,
partition_filter=partition_filter,
file_extensions=file_extensions,
)

filtered_paths = set(expanded_paths) - set(paths)
if filtered_paths:
logger.info(f"Filtered out {len(filtered_paths)} paths")
if listed_files:
paths, file_sizes = zip(*listed_files)
else:
paths, file_sizes = [], []

if dataset_kwargs is None:
if dataset_kwargs is not None:
logger.warning(
"Please note that `ParquetDatasource.__init__`s `dataset_kwargs` "
"is a deprecated parameter and will be removed in the future."
)
else:
dataset_kwargs = {}

if "partitioning" in dataset_kwargs:
@@ -238,7 +243,9 @@ def __init__(
# duplicating the partition data, we disable PyArrow's partitioning.
dataset_kwargs["partitioning"] = None

pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)
# NOTE: ParquetDataset only accepts a list of paths, hence we convert
# the paths to a list
pq_ds = get_parquet_dataset(list(paths), filesystem, dataset_kwargs)

# `read_schema` is the schema object that will be used to perform
# read operations.
@@ -277,12 +284,13 @@ def __init__(
"scheduling_strategy"
] = DataContext.get_current().scheduling_strategy

self._metadata = (
meta_provider.prefetch_file_metadata(
pq_ds.fragments, **prefetch_remote_args
self._metadata = [
ParquetFileMetadata(
num_bytes=num_bytes,
)
or []
)
for num_bytes in file_sizes
]

except OSError as e:
_handle_read_os_error(e, paths)
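A hedged sketch of the new metadata path (assuming `ParquetFileMetadata`, added by this PR, exposes at least the `num_bytes` field used here): the byte sizes returned by `_list_files` become lightweight per-file metadata entries, replacing the Parquet footer prefetch the meta provider used to perform.

file_sizes = [1_048_576, 4_194_304]  # hypothetical sizes (bytes) from the listing
metadata = [ParquetFileMetadata(num_bytes=n) for n in file_sizes]

# estimate_inmemory_data_size (further down) then scales these byte counts by the
# sampled encoding ratio instead of relying on prefetched row-group statistics.
encoding_ratio = 5.0  # hypothetical sampled ratio
estimated_in_memory_bytes = sum(m.num_bytes for m in metadata) * encoding_ratio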

@@ -292,7 +300,9 @@ def __init__(
# NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected
# network calls when `_ParquetDatasourceReader` is serialized. See
# `_SerializedFragment()` implementation for more details.
self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments]
self._pq_fragments = [
_NoIOSerializableFragmentWrapper(p) for p in pq_ds.fragments
]
self._pq_paths = [p.path for p in pq_ds.fragments]
self._meta_provider = meta_provider
self._block_udf = _block_udf
@@ -332,7 +342,7 @@ def __init__(
def estimate_inmemory_data_size(self) -> Optional[int]:
total_size = 0
for file_metadata in self._metadata:
total_size += file_metadata.total_byte_size
total_size += file_metadata.num_bytes
return total_size * self._encoding_ratio

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
@@ -443,18 +453,13 @@ def read_fragments(
data_columns,
partition_columns,
schema,
serialized_fragments: List[SerializedFragment],
fragments: List["ParquetFileFragment"],
include_paths: bool,
partitioning: Partitioning,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa

# Deserialize after loading the filesystem class.
fragments: List[
"pyarrow._dataset.ParquetFileFragment"
] = _deserialize_fragments_with_retry(serialized_fragments)

# Ensure that we're reading at least one dataset fragment.
assert len(fragments) > 0

@@ -510,30 +515,12 @@ def get_batch_iterable():
yield table


def _deserialize_fragments_with_retry(fragments):
# The deserialization retry helps when the upstream datasource is not able to
# handle overloaded read request or failed with some retriable failures.
# For example when reading data from HA hdfs service, hdfs might
# lose connection for some unknown reason expecially when
# simutaneously running many hyper parameter tuning jobs
# with ray.data parallelism setting at high value like the default 200
# Such connection failure can be restored with some waiting and retry.
return call_with_retry(
lambda: _deserialize_fragments(fragments),
description="deserialize fragments",
max_attempts=FILE_READING_RETRY,
)


def _sample_fragment(
to_batches_kwargs,
columns,
schema,
file_fragment: SerializedFragment,
fragment: "ParquetFileFragment",
) -> _SampleInfo:
# Sample the first rows batch from file fragment `serialized_fragment`.
fragment = _deserialize_fragments_with_retry([file_fragment])[0]

# If the fragment has no row groups, it's an empty or metadata-only file.
# Skip it by returning empty sample info.
if fragment.metadata.num_row_groups == 0:
@@ -1,5 +1,3 @@
from typing import List

from ray.data._internal.logical.interfaces import LogicalOperator


@@ -15,6 +13,6 @@ class Count(LogicalOperator):

def __init__(
self,
input_dependencies: List["LogicalOperator"],
input_op: LogicalOperator,
):
super().__init__("Count", input_dependencies)
super().__init__("Count", [input_op])
5 changes: 4 additions & 1 deletion python/ray/data/dataset.py
@@ -3396,7 +3396,10 @@ def count(self) -> int:
return meta_count

plan = self._plan.copy()
count_op = Count([self._logical_plan.dag])

# NOTE: Project the dataset to avoid carrying actual data when we're only
# interested in the total count
count_op = Count(Project(self._logical_plan.dag, cols=[]))
logical_plan = LogicalPlan(count_op, self.context)
count_ds = Dataset(plan, logical_plan)

43 changes: 42 additions & 1 deletion python/ray/data/datasource/file_meta_provider.py
@@ -20,7 +20,8 @@
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import RetryingPyFileSystem
from ray.data.block import BlockMetadata
from ray.data.datasource.partitioning import Partitioning
from ray.data.datasource.partitioning import Partitioning, PathPartitionFilter
from ray.data.datasource.path_util import _has_file_extension
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
@@ -243,6 +244,46 @@ def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str:
raise error


def _list_files(
paths: List[str],
filesystem: "RetryingPyFileSystem",
*,
partition_filter: Optional[PathPartitionFilter],
file_extensions: Optional[List[str]],
) -> List[Tuple[str, int]]:
return list(
_list_files_internal(
paths,
filesystem,
partition_filter=partition_filter,
file_extensions=file_extensions,
)
)


def _list_files_internal(
paths: List[str],
filesystem: "RetryingPyFileSystem",
*,
partition_filter: Optional[PathPartitionFilter],
file_extensions: Optional[List[str]],
) -> Iterator[Tuple[str, int]]:
default_meta_provider = DefaultFileMetadataProvider()

for path, file_size in default_meta_provider.expand_paths(paths, filesystem):
# HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
# files. To avoid this, we expand the input paths with the default metadata
# provider and then apply the partition filter or file extensions.
if (
partition_filter
and not partition_filter.apply(path)
) or not _has_file_extension(path, file_extensions):
continue

yield path, file_size
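A small usage sketch of `_list_files` under stated assumptions: a hypothetical local directory of mixed files, and a plain local filesystem for brevity (inside the datasource the filesystem is the retry-wrapped one created in `__init__`).

import pyarrow.fs as pafs

listed = _list_files(
    ["/tmp/my_table"],  # hypothetical directory with .parquet and stray files
    pafs.LocalFileSystem(),
    partition_filter=None,
    file_extensions=["parquet"],
)
# Each entry pairs a surviving path with its size in bytes; these sizes feed the
# lightweight ParquetFileMetadata entries built in ParquetDatasource.__init__.
for path, num_bytes in listed:
    print(path, num_bytes)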


def _expand_paths(
paths: List[str],
filesystem: "RetryingPyFileSystem",