diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 44ccbd1e55b2..9589036aefd4 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -24,7 +24,6 @@ RetryingPyFileSystem, _check_pyarrow_version, _is_local_scheme, - call_with_retry, iterate_with_retry, ) from ray.data.block import Block, BlockAccessor @@ -33,10 +32,13 @@ from ray.data.datasource.datasource import ReadTask from ray.data.datasource.file_based_datasource import FileShuffleConfig from ray.data.datasource.file_meta_provider import ( - DefaultFileMetadataProvider, _handle_read_os_error, + _list_files, +) +from ray.data.datasource.parquet_meta_provider import ( + ParquetFileMetadata, + ParquetMetadataProvider, ) -from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider from ray.data.datasource.partitioning import ( PartitionDataType, Partitioning, @@ -101,12 +103,19 @@ class _SampleInfo: estimated_bytes_per_row: Optional[int] -# TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a -# raw pyarrow file fragment causes S3 network calls. -class SerializedFragment: - def __init__(self, frag: "ParquetFileFragment"): - self._data = cloudpickle.dumps( - (frag.format, frag.path, frag.filesystem, frag.partition_expression) +class _NoIOSerializableFragmentWrapper: + """This is a workaround to avoid utilizing `ParquetFileFragment` original + serialization protocol that actually does network RPCs during serialization + (to fetch metadata)""" + + def __init__(self, f: "ParquetFileFragment"): + self._fragment = f + + def __reduce__(self): + return self._fragment.format.make_fragment, ( + self._fragment.path, + self._fragment.filesystem, + self._fragment.partition_expression, ) def deserialize(self) -> "ParquetFileFragment": @@ -122,7 +131,7 @@ def deserialize(self) -> "ParquetFileFragment": # Visible for test mocking. def _deserialize_fragments( - serialized_fragments: List[SerializedFragment], + serialized_fragments: List[_NoIOSerializableFragmentWrapper], ) -> List["pyarrow._dataset.ParquetFileFragment"]: return [p.deserialize() for p in serialized_fragments] @@ -204,28 +213,24 @@ def __init__( retryable_errors=DataContext.get_current().retried_io_errors, ) - # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet - # files. To avoid this, we expand the input paths with the default metadata - # provider and then apply the partition filter or file extensions. 
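# --- Reviewer note (illustrative sketch, not part of the patch) ---------------
# _NoIOSerializableFragmentWrapper.__reduce__ above stores only a rebuild recipe
# (callable + args), so pickling the wrapper triggers no filesystem/metadata
# RPCs; the fragment is rebuilt via ParquetFileFormat.make_fragment only at
# unpickling time. The same pattern in miniature, with a stand-in class instead
# of pyarrow:
import pickle


class ExpensiveHandle:
    """Stand-in for ParquetFileFragment; constructing it is the 'expensive' step."""

    def __init__(self, path: str):
        self.path = path
        print(f"(re)building handle for {path}")  # imagine metadata RPCs here


class LazyWrapper:
    def __init__(self, handle: ExpensiveHandle):
        self._handle = handle

    def __reduce__(self):
        # Serialize only the constructor and its arguments; as in the patch,
        # unpickling yields the underlying object, not the wrapper itself.
        return ExpensiveHandle, (self._handle.path,)


blob = pickle.dumps(LazyWrapper(ExpensiveHandle("s3://bucket/file.parquet")))
restored = pickle.loads(blob)  # the rebuild (and any I/O) happens only here
assert restored.path == "s3://bucket/file.parquet"
# ------------------------------------------------------------------------------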
- if partition_filter is not None or file_extensions is not None: - default_meta_provider = DefaultFileMetadataProvider() - expanded_paths, _ = map( - list, zip(*default_meta_provider.expand_paths(paths, filesystem)) - ) - - paths = list(expanded_paths) - if partition_filter is not None: - paths = partition_filter(paths) - if file_extensions is not None: - paths = [ - path for path in paths if _has_file_extension(path, file_extensions) - ] + listed_files = _list_files( + paths, + filesystem, + partition_filter=partition_filter, + file_extensions=file_extensions, + ) - filtered_paths = set(expanded_paths) - set(paths) - if filtered_paths: - logger.info(f"Filtered out {len(filtered_paths)} paths") + if listed_files: + paths, file_sizes = zip(*listed_files) + else: + paths, file_sizes = [], [] - if dataset_kwargs is None: + if dataset_kwargs is not None: + logger.warning( + "Please note that `ParquetDatasource.__init__`s `dataset_kwargs` " + "is a deprecated parameter and will be removed in the future." + ) + else: dataset_kwargs = {} if "partitioning" in dataset_kwargs: @@ -238,7 +243,9 @@ def __init__( # duplicating the partition data, we disable PyArrow's partitioning. dataset_kwargs["partitioning"] = None - pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs) + # NOTE: ParquetDataset only accepts list of paths, hence we need to convert + # it to a list + pq_ds = get_parquet_dataset(list(paths), filesystem, dataset_kwargs) # `read_schema` is the schema object that will be used to perform # read operations. @@ -277,12 +284,13 @@ def __init__( "scheduling_strategy" ] = DataContext.get_current().scheduling_strategy - self._metadata = ( - meta_provider.prefetch_file_metadata( - pq_ds.fragments, **prefetch_remote_args + self._metadata = [ + ParquetFileMetadata( + num_bytes=num_bytes, ) - or [] - ) + for num_bytes in file_sizes + ] + except OSError as e: _handle_read_os_error(e, paths) @@ -292,7 +300,9 @@ def __init__( # NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected # network calls when `_ParquetDatasourceReader` is serialized. See # `_SerializedFragment()` implementation for more details. - self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments] + self._pq_fragments = [ + _NoIOSerializableFragmentWrapper(p) for p in pq_ds.fragments + ] self._pq_paths = [p.path for p in pq_ds.fragments] self._meta_provider = meta_provider self._block_udf = _block_udf @@ -332,7 +342,7 @@ def __init__( def estimate_inmemory_data_size(self) -> Optional[int]: total_size = 0 for file_metadata in self._metadata: - total_size += file_metadata.total_byte_size + total_size += file_metadata.num_bytes return total_size * self._encoding_ratio def get_read_tasks(self, parallelism: int) -> List[ReadTask]: @@ -443,18 +453,13 @@ def read_fragments( data_columns, partition_columns, schema, - serialized_fragments: List[SerializedFragment], + fragments: List["ParquetFileFragment"], include_paths: bool, partitioning: Partitioning, ) -> Iterator["pyarrow.Table"]: # This import is necessary to load the tensor extension type. from ray.data.extensions.tensor_extension import ArrowTensorType # noqa - # Deserialize after loading the filesystem class. - fragments: List[ - "pyarrow._dataset.ParquetFileFragment" - ] = _deserialize_fragments_with_retry(serialized_fragments) - # Ensure that we're reading at least one dataset fragment. 
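# --- Reviewer note (illustrative sketch, not part of the patch) ---------------
# With this change, estimate_inmemory_data_size() above no longer needs Parquet
# footer metadata: it is the sum of the file sizes returned by _list_files,
# scaled by the sampled encoding ratio. The names below (FileMeta,
# estimate_inmemory_size) are stand-ins for illustration, not Ray APIs:
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FileMeta:  # mirrors ParquetFileMetadata(num_bytes=..., num_rows=None)
    num_bytes: int
    num_rows: Optional[int] = None


def estimate_inmemory_size(metas: List[FileMeta], encoding_ratio: float) -> int:
    return int(sum(m.num_bytes for m in metas) * encoding_ratio)


print(estimate_inmemory_size([FileMeta(1_000), FileMeta(2_000)], 5.0))  # 15000
# ------------------------------------------------------------------------------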
assert len(fragments) > 0 @@ -510,30 +515,12 @@ def get_batch_iterable(): yield table -def _deserialize_fragments_with_retry(fragments): - # The deserialization retry helps when the upstream datasource is not able to - # handle overloaded read request or failed with some retriable failures. - # For example when reading data from HA hdfs service, hdfs might - # lose connection for some unknown reason expecially when - # simutaneously running many hyper parameter tuning jobs - # with ray.data parallelism setting at high value like the default 200 - # Such connection failure can be restored with some waiting and retry. - return call_with_retry( - lambda: _deserialize_fragments(fragments), - description="deserialize fragments", - max_attempts=FILE_READING_RETRY, - ) - - def _sample_fragment( to_batches_kwargs, columns, schema, - file_fragment: SerializedFragment, + fragment: "ParquetFileFragment", ) -> _SampleInfo: - # Sample the first rows batch from file fragment `serialized_fragment`. - fragment = _deserialize_fragments_with_retry([file_fragment])[0] - # If the fragment has no row groups, it's an empty or metadata-only file. # Skip it by returning empty sample info. if fragment.metadata.num_row_groups == 0: diff --git a/python/ray/data/_internal/logical/operators/count_operator.py b/python/ray/data/_internal/logical/operators/count_operator.py index 409c99e3c000..39ec706f7e50 100644 --- a/python/ray/data/_internal/logical/operators/count_operator.py +++ b/python/ray/data/_internal/logical/operators/count_operator.py @@ -1,5 +1,3 @@ -from typing import List - from ray.data._internal.logical.interfaces import LogicalOperator @@ -15,6 +13,6 @@ class Count(LogicalOperator): def __init__( self, - input_dependencies: List["LogicalOperator"], + input_op: LogicalOperator, ): - super().__init__("Count", input_dependencies) + super().__init__("Count", [input_op]) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 079ac1f0fcf7..7b96cf7a3051 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3396,7 +3396,10 @@ def count(self) -> int: return meta_count plan = self._plan.copy() - count_op = Count([self._logical_plan.dag]) + + # NOTE: Project the dataset to avoid the need to carrying actual + # data when we're only interested in the total count + count_op = Count(Project(self._logical_plan.dag, cols=[])) logical_plan = LogicalPlan(count_op, self.context) count_ds = Dataset(plan, logical_plan) diff --git a/python/ray/data/datasource/file_meta_provider.py b/python/ray/data/datasource/file_meta_provider.py index ed5ae26f903c..354f761d9651 100644 --- a/python/ray/data/datasource/file_meta_provider.py +++ b/python/ray/data/datasource/file_meta_provider.py @@ -20,7 +20,8 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import RetryingPyFileSystem from ray.data.block import BlockMetadata -from ray.data.datasource.partitioning import Partitioning +from ray.data.datasource.partitioning import Partitioning, PathPartitionFilter +from ray.data.datasource.path_util import _has_file_extension from ray.util.annotations import DeveloperAPI if TYPE_CHECKING: @@ -243,6 +244,46 @@ def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str: raise error +def _list_files( + paths: List[str], + filesystem: "RetryingPyFileSystem", + *, + partition_filter: Optional[PathPartitionFilter], + file_extensions: Optional[List[str]], +) -> List[Tuple[str, int]]: + return list( + _list_files_internal( + paths, + filesystem, + 
partition_filter=partition_filter, + file_extensions=file_extensions, + ) + ) + + +def _list_files_internal( + paths: List[str], + filesystem: "RetryingPyFileSystem", + *, + partition_filter: Optional[PathPartitionFilter], + file_extensions: Optional[List[str]], +) -> Iterator[Tuple[str, int]]: + default_meta_provider = DefaultFileMetadataProvider() + + for path, file_size in default_meta_provider.expand_paths(paths, filesystem): + # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet + # files. To avoid this, we expand the input paths with the default metadata + # provider and then apply the partition filter or file extensions. + if ( + partition_filter + and not partition_filter.apply(path) + or not _has_file_extension(path, file_extensions) + ): + continue + + yield path, file_size + + def _expand_paths( paths: List[str], filesystem: "RetryingPyFileSystem", diff --git a/python/ray/data/datasource/parquet_meta_provider.py b/python/ray/data/datasource/parquet_meta_provider.py index 73a3c41ef6e2..25b33f27411e 100644 --- a/python/ray/data/datasource/parquet_meta_provider.py +++ b/python/ray/data/datasource/parquet_meta_provider.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, List, Optional +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Tuple -import ray.cloudpickle as cloudpickle from ray.data._internal.util import call_with_retry from ray.data.block import BlockMetadata from ray.data.datasource.file_meta_provider import ( @@ -11,8 +12,7 @@ if TYPE_CHECKING: import pyarrow - - from ray.data._internal.datasource.parquet_datasource import SerializedFragment + from pyarrow.dataset import ParquetFileFragment FRAGMENTS_PER_META_FETCH = 6 @@ -28,35 +28,21 @@ RETRY_MAX_BACKOFF_S_FOR_META_FETCH_TASK = 64 -class _ParquetFileFragmentMetaData: - """Class to store metadata of a Parquet file fragment. This includes - all attributes from `pyarrow.parquet.FileMetaData` except for `schema`, - which is stored in `self.schema_pickled` as a pickled object from - `cloudpickle.loads()`, used in deduplicating schemas across multiple fragments.""" - - def __init__(self, fragment_metadata: "pyarrow.parquet.FileMetaData"): - self.created_by = fragment_metadata.created_by - self.format_version = fragment_metadata.format_version - self.num_columns = fragment_metadata.num_columns - self.num_row_groups = fragment_metadata.num_row_groups - self.num_rows = fragment_metadata.num_rows - self.serialized_size = fragment_metadata.serialized_size +logger = logging.getLogger(__name__) - # Serialize the schema directly in the constructor - schema_ser = cloudpickle.dumps(fragment_metadata.schema.to_arrow_schema()) - self.schema_pickled = schema_ser - # Calculate the total byte size of the file fragment using the original - # object, as it is not possible to access row groups from this class. 
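# --- Reviewer note (illustrative sketch, not part of the patch) ---------------
# The byte accounting removed from this class survives as the _get_total_bytes
# helper added at the bottom of this file: both sum total_byte_size across the
# row groups reported by the Parquet footer. A self-contained demo (the /tmp
# path is a throwaway placeholder):
import pyarrow as pa
import pyarrow.parquet as pq

pq.write_table(pa.table({"x": list(range(1_000))}), "/tmp/footer_demo.parquet")
md = pq.read_metadata("/tmp/footer_demo.parquet")

total_bytes = sum(md.row_group(i).total_byte_size for i in range(md.num_row_groups))
print(md.num_rows, md.num_row_groups, total_bytes)
# ------------------------------------------------------------------------------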
- self.total_byte_size = 0 - for row_group_idx in range(fragment_metadata.num_row_groups): - row_group_metadata = fragment_metadata.row_group(row_group_idx) - self.total_byte_size += row_group_metadata.total_byte_size +@DeveloperAPI(stability="alpha") +@dataclass +class ParquetFileMetadata: + num_bytes: int + num_rows: Optional[int] = field(default=None) - def set_schema_pickled(self, schema_pickled: bytes): - """Note: to get the underlying schema, use - `cloudpickle.loads(self.schema_pickled)`.""" - self.schema_pickled = schema_pickled + @classmethod + def from_(cls, pqm: "pyarrow.parquet.FileMetaData"): + return ParquetFileMetadata( + num_rows=pqm.num_rows, + num_bytes=_get_total_bytes(pqm), + ) @DeveloperAPI @@ -68,7 +54,7 @@ def _get_block_metadata( paths: List[str], *, num_fragments: int, - prefetched_metadata: Optional[List["_ParquetFileFragmentMetaData"]], + prefetched_metadata: Optional[List["ParquetFileMetadata"]], ) -> BlockMetadata: """Resolves and returns block metadata for files of a single dataset block. @@ -88,11 +74,13 @@ def _get_block_metadata( and len(prefetched_metadata) == num_fragments and all(m is not None for m in prefetched_metadata) ): + total_bytes, total_rows = self._derive_totals(prefetched_metadata) + # Fragment metadata was available, construct a normal # BlockMetadata. block_metadata = BlockMetadata( - num_rows=sum(m.num_rows for m in prefetched_metadata), - size_bytes=sum(m.total_byte_size for m in prefetched_metadata), + num_rows=total_rows, + size_bytes=total_bytes, input_files=paths, exec_stats=None, ) # Exec stats filled in later. @@ -107,11 +95,29 @@ def _get_block_metadata( ) return block_metadata + @staticmethod + def _derive_totals( + prefetched_metadata: List["ParquetFileMetadata"], + ) -> Tuple[int, int]: + total_bytes = 0 + total_rows = 0 + + for m in prefetched_metadata: + total_bytes += m.num_bytes + + if total_rows is not None: + if m.num_rows is not None: + total_rows += m.num_rows + else: + total_rows = None + + return total_bytes, total_rows + def prefetch_file_metadata( self, fragments: List["pyarrow.dataset.ParquetFileFragment"], **ray_remote_args, - ) -> Optional[List[_ParquetFileFragmentMetaData]]: + ) -> Optional[List[ParquetFileMetadata]]: """Pre-fetches file metadata for all Parquet file fragments in a single batch. Subsets of the metadata returned will be provided as input to subsequent calls @@ -126,15 +132,18 @@ def prefetch_file_metadata( must be returned in the same order as all input file fragments, such that `metadata[i]` always contains the metadata for `fragments[i]`. """ - from ray.data._internal.datasource.parquet_datasource import SerializedFragment + from ray.data._internal.datasource.parquet_datasource import ( + _NoIOSerializableFragmentWrapper, + ) if len(fragments) > PARALLELIZE_META_FETCH_THRESHOLD: # Wrap Parquet fragments in serialization workaround. - fragments = [SerializedFragment(fragment) for fragment in fragments] + fragments = [ + _NoIOSerializableFragmentWrapper(fragment) for fragment in fragments + ] # Fetch Parquet metadata in parallel using Ray tasks. - - def fetch_func(fragments): - return _fetch_metadata_serialization_wrapper( + def _remote_fetch(fragments: List["ParquetFileFragment"]): + return _fetch_metadata_with_retry( fragments, # Ensure that retry settings are propagated to remote tasks. 
retry_match=RETRY_EXCEPTIONS_FOR_META_FETCH_TASK, @@ -145,13 +154,13 @@ def fetch_func(fragments): raw_metadata = list( _fetch_metadata_parallel( fragments, - fetch_func, + _remote_fetch, FRAGMENTS_PER_META_FETCH, **ray_remote_args, ) ) - return _dedupe_schemas(raw_metadata) + return raw_metadata else: # We don't deduplicate schemas in this branch because they're already @@ -162,20 +171,15 @@ def fetch_func(fragments): return raw_metadata -def _fetch_metadata_serialization_wrapper( - fragments: List["SerializedFragment"], +def _fetch_metadata_with_retry( + fragments: List["ParquetFileFragment"], retry_match: Optional[List[str]], retry_max_attempts: int, retry_max_interval: int, -) -> List["_ParquetFileFragmentMetaData"]: - from ray.data._internal.datasource.parquet_datasource import ( - _deserialize_fragments_with_retry, - ) - - deserialized_fragments = _deserialize_fragments_with_retry(fragments) +) -> List["ParquetFileMetadata"]: try: metadata = call_with_retry( - lambda: _fetch_metadata(deserialized_fragments), + lambda: _fetch_metadata(fragments), description="fetch metdata", match=retry_match, max_attempts=retry_max_attempts, @@ -215,53 +219,18 @@ def _fetch_metadata_serialization_wrapper( def _fetch_metadata( fragments: List["pyarrow.dataset.ParquetFileFragment"], -) -> List[_ParquetFileFragmentMetaData]: +) -> List["ParquetFileMetadata"]: fragment_metadatas = [] for f in fragments: try: # Convert directly to _ParquetFileFragmentMetaData - fragment_metadatas.append(_ParquetFileFragmentMetaData(f.metadata)) - except AttributeError: + fragment_metadatas.append(ParquetFileMetadata.from_(f.metadata)) + except AttributeError as ae: + logger.warning(f"Failed to extract metadata from parquet file: {ae}") break # Deduplicate schemas to reduce memory usage - return _dedupe_schemas(fragment_metadatas) - - -def _dedupe_schemas( - metadatas: List[_ParquetFileFragmentMetaData], -) -> List[_ParquetFileFragmentMetaData]: - """Deduplicates schema objects across existing _ParquetFileFragmentMetaData objects. - - For datasets with a large number of columns, the pickled schema can be very large. - This function reduces memory usage by ensuring that identical schemas across multiple - fragment metadata objects reference the same underlying pickled schema object, - rather than each fragment maintaining its own copy. - - Args: - metadatas: List of _ParquetFileFragmentMetaData objects that already have - pickled schemas set. - - Returns: - The same list of _ParquetFileFragmentMetaData objects, but with duplicate - schemas deduplicated to reference the same object in memory. 
- """ - schema_to_id = {} # schema_ser -> schema_id - id_to_schema = {} # schema_id -> schema_ser - - for metadata in metadatas: - # Get the current schema serialization - schema_ser = metadata.schema_pickled - - if schema_ser not in schema_to_id: - # This is a new unique schema - schema_id = len(schema_to_id) - schema_to_id[schema_ser] = schema_id - id_to_schema[schema_id] = schema_ser - # No need to set schema_pickled - it already has the correct value - else: - # This schema already exists, reuse the existing one - schema_id = schema_to_id[schema_ser] - existing_schema_ser = id_to_schema[schema_id] - metadata.set_schema_pickled(existing_schema_ser) + return fragment_metadatas + - return metadatas +def _get_total_bytes(pqm: "pyarrow.parquet.FileMetaData") -> int: + return sum(pqm.row_group(i).total_byte_size for i in range(pqm.num_row_groups)) diff --git a/python/ray/data/datasource/partitioning.py b/python/ray/data/datasource/partitioning.py index 2d83fe6b67de..20a626b09bce 100644 --- a/python/ray/data/datasource/partitioning.py +++ b/python/ray/data/datasource/partitioning.py @@ -434,11 +434,12 @@ def __call__(self, paths: List[str]) -> List[str]: """ filtered_paths = paths if self._filter_fn is not None: - filtered_paths = [ - path for path in paths if self._filter_fn(self._parser(path)) - ] + filtered_paths = [path for path in paths if self.apply(path)] return filtered_paths + def apply(self, path: str) -> bool: + return self._filter_fn(self._parser(path)) + @property def parser(self) -> PathPartitionParser: """Returns the path partition parser for this filter.""" diff --git a/python/ray/data/datasource/path_util.py b/python/ray/data/datasource/path_util.py index 6498300caa9f..5d9527243f36 100644 --- a/python/ray/data/datasource/path_util.py +++ b/python/ray/data/datasource/path_util.py @@ -39,7 +39,7 @@ def _has_file_extension(path: str, extensions: Optional[List[str]]) -> bool: def _resolve_paths_and_filesystem( paths: Union[str, List[str]], - filesystem: "pyarrow.fs.FileSystem" = None, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, ) -> Tuple[List[str], "pyarrow.fs.FileSystem"]: """ Resolves and normalizes all provided paths, infers a filesystem from the @@ -69,7 +69,7 @@ def _resolve_paths_and_filesystem( elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths): raise ValueError( "Expected `paths` to be a `str`, `pathlib.Path`, or `list[str]`, but got " - f"`{paths}`." 
+ f"`{paths}`" ) elif len(paths) == 0: raise ValueError("Must provide at least one path.") diff --git a/python/ray/data/tests/test_consumption.py b/python/ray/data/tests/test_consumption.py index c5b4fd3f5b65..c8437d5aaa22 100644 --- a/python/ray/data/tests/test_consumption.py +++ b/python/ray/data/tests/test_consumption.py @@ -1863,7 +1863,7 @@ def test_dataset_plan_as_string(ray_start_cluster): ds = ray.data.read_parquet("example://iris.parquet", override_num_blocks=8) assert ds._plan.get_plan_as_string(type(ds)) == ( "Dataset(\n" - " num_rows=150,\n" + " num_rows=?,\n" " schema={\n" " sepal.length: double,\n" " sepal.width: double,\n" @@ -1882,7 +1882,7 @@ def test_dataset_plan_as_string(ray_start_cluster): " +- MapBatches()\n" " +- MapBatches()\n" " +- Dataset(\n" - " num_rows=150,\n" + " num_rows=?,\n" " schema={\n" " sepal.length: double,\n" " sepal.width: double,\n" diff --git a/python/ray/data/tests/test_metadata_provider.py b/python/ray/data/tests/test_metadata_provider.py index ef4899abd085..c67e0890d506 100644 --- a/python/ray/data/tests/test_metadata_provider.py +++ b/python/ray/data/tests/test_metadata_provider.py @@ -27,6 +27,7 @@ _get_file_infos_parallel, _get_file_infos_serial, ) +from ray.data.datasource.parquet_meta_provider import _get_total_bytes from ray.data.datasource.path_util import ( _resolve_paths_and_filesystem, _unwrap_protocol, @@ -40,13 +41,6 @@ def df_to_csv(dataframe, path, **kwargs): dataframe.to_csv(path, **kwargs) -def _get_parquet_file_meta_size_bytes(file_metas): - return sum( - sum(m.row_group(i).total_byte_size for i in range(m.num_row_groups)) - for m in file_metas - ) - - def _get_file_sizes_bytes(paths, fs): from pyarrow.fs import FileType @@ -111,9 +105,11 @@ def test_default_parquet_metadata_provider(fs, data_path): num_fragments=len(pq_ds.fragments), prefetched_metadata=fragment_file_metas, ) - expected_meta_size_bytes = _get_parquet_file_meta_size_bytes( - [f.metadata for f in pq_ds.fragments] + + expected_meta_size_bytes = sum( + [_get_total_bytes(f.metadata) for f in pq_ds.fragments] ) + assert meta.size_bytes == expected_meta_size_bytes assert meta.num_rows == 6 assert len(paths) == 2 diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py index bacf2c649941..67a517cfc002 100644 --- a/python/ray/data/tests/test_parquet.py +++ b/python/ray/data/tests/test_parquet.py @@ -2,7 +2,7 @@ import shutil import time from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import numpy as np import pandas as pd @@ -23,8 +23,6 @@ from ray.data._internal.datasource.parquet_datasource import ( NUM_CPUS_FOR_META_FETCH_TASK, ParquetDatasource, - SerializedFragment, - _deserialize_fragments_with_retry, ) from ray.data._internal.execution.interfaces.ref_bundle import ( _ref_bundles_iterator_to_block_refs_list, @@ -114,70 +112,6 @@ def test_include_paths( assert paths == [path, path] -@pytest.mark.parametrize( - "fs,data_path", - [ - (lazy_fixture("local_fs"), lazy_fixture("local_path")), - ], -) -def test_parquet_deserialize_fragments_with_retry( - ray_start_regular_shared, fs, data_path, monkeypatch -): - setup_data_path = _unwrap_protocol(data_path) - df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]}) - table = pa.Table.from_pandas(df1) - path1 = os.path.join(setup_data_path, "test1.parquet") - pq.write_table(table, path1, filesystem=fs) - df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]}) - table = pa.Table.from_pandas(df2) - path2 = 
os.path.join(setup_data_path, "test2.parquet") - pq.write_table(table, path2, filesystem=fs) - - dataset_kwargs = {} - pq_ds = pq.ParquetDataset( - data_path, - **dataset_kwargs, - filesystem=fs, - ) - serialized_fragments = [SerializedFragment(p) for p in pq_ds.fragments] - - # test 1st attempt succeed - fragments = _deserialize_fragments_with_retry(serialized_fragments) - assert "test1.parquet" in fragments[0].path - assert "test2.parquet" in fragments[1].path - - # test the 3rd attempt succeed with a mock function constructed - # to throw in the first two attempts - class MockDeserializer: - def __init__(self, planned_exp_or_return): - self.planned_exp_or_return = planned_exp_or_return - self.cur_index = 0 - - def __call__(self, *args: Any, **kwds: Any) -> Any: - exp_or_ret = self.planned_exp_or_return[self.cur_index] - self.cur_index += 1 - if isinstance(exp_or_ret, Exception): - raise exp_or_ret - else: - return exp_or_ret - - mock_deserializer = MockDeserializer( - [ - Exception("1st mock failed attempt"), - Exception("2nd mock failed attempt"), - fragments, - ] - ) - monkeypatch.setattr( - ray.data._internal.datasource.parquet_datasource, - "_deserialize_fragments", - mock_deserializer, - ) - retried_fragments = _deserialize_fragments_with_retry(serialized_fragments) - assert "test1.parquet" in retried_fragments[0].path - assert "test2.parquet" in retried_fragments[1].path - - @pytest.mark.parametrize( "fs,data_path", [ @@ -267,7 +201,10 @@ def test_parquet_read_meta_provider(ray_start_regular_shared, fs, data_path): pq.write_table(table, path2, filesystem=fs) expected_num_rows = len(df1) + len(df2) - expected_byte_size = 787500 + # NOTE: Since we're testing against various Pyarrow versions size + # on disk could be varying slightly as it on top of data it also + # includes metadata + expected_byte_size = pytest.approx(463500, abs=500) # # Case 1: Test metadata fetching happy path (obtaining, caching and propagating @@ -290,14 +227,19 @@ def prefetch_file_metadata(self, fragments, **ray_remote_args): ) # Expect precomputed row counts and block sizes to be missing. 
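# --- Reviewer note (illustrative sketch, not part of the patch) ---------------
# Metadata now comes from file listing alone (ParquetFileMetadata carries
# num_bytes, with num_rows left as None unless a footer was read), so row counts
# can no longer be precomputed here and Dataset.count() instead counts over an
# empty projection. The patch's _derive_totals collapses the row total to None
# as soon as any fragment lacks a count; a standalone stand-in:
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FileMeta:
    num_bytes: int
    num_rows: Optional[int] = None


def derive_totals(metas: List[FileMeta]) -> Tuple[int, Optional[int]]:
    total_bytes = 0
    total_rows: Optional[int] = 0
    for m in metas:
        total_bytes += m.num_bytes
        if total_rows is not None:
            total_rows = total_rows + m.num_rows if m.num_rows is not None else None
    return total_bytes, total_rows


print(derive_totals([FileMeta(100, num_rows=3), FileMeta(200)]))  # (300, None)
# ------------------------------------------------------------------------------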
- assert ds._meta_count() == expected_num_rows + assert ds._meta_count() is None read_op = ds._plan._logical_plan.dag # Assert Read op metadata propagation - assert read_op.infer_metadata() == BlockMetadata( - num_rows=expected_num_rows, - size_bytes=expected_byte_size, + metadata = read_op.infer_metadata() + # NOTE: We assert on byte size separately, since we're using `pytest.approx` + # object for it + assert metadata.size_bytes == expected_byte_size + + assert metadata == BlockMetadata( + num_rows=None, + size_bytes=metadata.size_bytes, exec_stats=None, input_files=[path1, path2], ) @@ -371,8 +313,6 @@ def prefetch_file_metadata(self, fragments, **ray_remote_args): assert ds.schema() == Schema(expected_schema) assert set(ds.input_files()) == {path1, path2} - assert ds._plan.has_computed_output() - @pytest.mark.parametrize( "fs,data_path", @@ -899,7 +839,11 @@ def test_parquet_reader_estimate_data_size(shutdown_only, tmp_path): ctx.decoding_size_estimation = True try: tensor_output_path = os.path.join(tmp_path, "tensor") - ray.data.range_tensor(1000, shape=(1000,)).write_parquet(tensor_output_path) + # NOTE: It's crucial to override # of blocks to get stable # of files + # produced and make sure data size estimates are stable + ray.data.range_tensor( + 1000, shape=(1000,), override_num_blocks=10 + ).write_parquet(tensor_output_path) ds = ray.data.read_parquet( tensor_output_path, meta_provider=ParquetMetadataProvider() ) @@ -940,7 +884,7 @@ def test_parquet_reader_estimate_data_size(shutdown_only, tmp_path): assert ds._plan.initial_num_blocks() > 1 data_size = ds.size_bytes() assert ( - data_size >= 800_000 and data_size <= 2_000_000 + data_size >= 800_000 and data_size <= 2_200_000 ), "estimated data size is out of expected bound" data_size = ds.materialize().size_bytes() assert ( @@ -955,7 +899,7 @@ def test_parquet_reader_estimate_data_size(shutdown_only, tmp_path): ), "encoding ratio is out of expected bound" data_size = datasource.estimate_inmemory_data_size() assert ( - data_size >= 800_000 and data_size <= 2_000_000 + data_size >= 800_000 and data_size <= 2_200_000 ), "estimated data size is out of expected bound" assert ( data_size
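# --- Reviewer note (illustrative sketch, not part of the patch) ---------------
# pytest.approx with an absolute tolerance is what lets the size assertions
# above absorb small per-PyArrow-version differences in on-disk footer overhead;
# size_bytes is compared separately because the approx object would otherwise
# sit inside the BlockMetadata equality check. Its behavior in isolation:
import pytest

expected = pytest.approx(463_500, abs=500)
assert 463_200 == expected        # within +/- 500 bytes
assert not (464_200 == expected)  # outside the tolerance
# ------------------------------------------------------------------------------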