diff --git a/python/ray/data/checkpoint/checkpoint_filter.py b/python/ray/data/checkpoint/checkpoint_filter.py index 20d6e800a7d2..8436f7b7d5ed 100644 --- a/python/ray/data/checkpoint/checkpoint_filter.py +++ b/python/ray/data/checkpoint/checkpoint_filter.py @@ -1,5 +1,6 @@ import abc import logging +import os import time from typing import List, Optional @@ -15,6 +16,8 @@ from ray.data.datasource.path_util import _unwrap_protocol from ray.types import ObjectRef +import psutil + logger = logging.getLogger(__name__) @@ -35,26 +38,43 @@ def __init__(self, config: CheckpointConfig): @ray.remote(max_retries=-1) -def _combine_chunks(ckpt_block: pyarrow.Table) -> pyarrow.Table: +def _combine_chunks(ckpt_block: pyarrow.Table, id_column: str) -> numpy.ndarray: """Combine chunks for the checkpoint block. Args: ckpt_block: The checkpoint block to combine chunks for + id_column: The id column Returns: - The combined checkpoint block + The numpy.ndarray of combined checkpoint block """ - from ray.data._internal.arrow_ops.transform_pyarrow import combine_chunks - combined_ckpt_block = combine_chunks(ckpt_block) + # Combine chunks of ckpt_block + combined_ckpt_block = transform_pyarrow.combine_chunks(ckpt_block) logger.debug( "Checkpoint block stats for id column checkpoint: Combined block: type=%s, %d rows, %d bytes", combined_ckpt_block.schema.to_string(), combined_ckpt_block.num_rows, combined_ckpt_block.nbytes, ) + if combined_ckpt_block.num_rows == 0: + return numpy.array([]) + + # In some cases(e.g., the size of combined_ckpt_block[id_column] > 2GB), there may be multiple chunks. + combine_ckpt_chunks = combined_ckpt_block[id_column].chunks + + logger.debug( + "Checkpoint stats for id column chunks: Num of chunks: %d", + len(combine_ckpt_chunks), + ) - return combined_ckpt_block + # Convert checkpoint chunk to numpy for fast search. + # Use internal helper function for consistency and robustness (handles null-typed arrays, etc.) 
+ ckpt_arrays = [] + for chunk in combine_ckpt_chunks: + ckpt_arrays.append(transform_pyarrow.to_numpy(chunk, zero_copy_only=False)) + final_ckpt_array = numpy.concatenate(ckpt_arrays) + return final_ckpt_array class CheckpointLoader: @@ -81,11 +101,11 @@ def __init__( self.id_column = id_column self.checkpoint_path_partition_filter = checkpoint_path_partition_filter - def load_checkpoint(self) -> ObjectRef[Block]: + def load_checkpoint(self) -> ObjectRef[numpy.ndarray]: """Loading checkpoint data. Returns: - ObjectRef[Block]: ObjectRef to the checkpointed IDs block. + ObjectRef[numpy.ndarray]: ObjectRef to the checkpointed IDs array. """ start_t = time.time() @@ -118,20 +138,21 @@ def load_checkpoint(self) -> ObjectRef[Block]: metadata: BlockMetadata = ref_bundle.blocks[0][1] # Post-process the block - checkpoint_block_ref: ObjectRef[Block] = self._postprocess_block(block_ref) + checkpoint_ndarray_ref: ObjectRef[numpy.ndarray] = self._postprocess_block( + block_ref + ) # Validate the loaded checkpoint self._validate_loaded_checkpoint(schema, metadata) logger.info( - "Checkpoint loaded for %s in %.2f seconds. SizeBytes = %d, Schema = %s", + "Checkpoint loaded for %s in %.2f seconds. Arrow SizeBytes = %d, Schema = %s", type(self).__name__, time.time() - start_t, metadata.size_bytes, schema.to_string(), ) - - return checkpoint_block_ref + return checkpoint_ndarray_ref @abc.abstractmethod def _preprocess_data_pipeline( @@ -140,9 +161,11 @@ def _preprocess_data_pipeline( """Pre-process the checkpoint dataset. 
To be implemented by subclasses.""" raise NotImplementedError("Subclasses must implement this method") - def _postprocess_block(self, block_ref: ObjectRef[Block]) -> ObjectRef[Block]: + def _postprocess_block( + self, block_ref: ObjectRef[Block] + ) -> ObjectRef[numpy.ndarray]: """Combine the block so it has fewer chunks.""" - return _combine_chunks.remote(block_ref) + return _combine_chunks.remote(block_ref, self.id_column) def _validate_loaded_checkpoint( self, schema: Schema, metadata: BlockMetadata @@ -174,11 +197,11 @@ def _preprocess_data_pipeline( class BatchBasedCheckpointFilter(CheckpointFilter): """CheckpointFilter for batch-based backends.""" - def load_checkpoint(self) -> ObjectRef[Block]: + def load_checkpoint(self) -> ObjectRef[numpy.ndarray]: """Load checkpointed ids as a sorted block. Returns: - ObjectRef[Block]: ObjectRef to the checkpointed IDs block. + ObjectRef[numpy.ndarray]: ObjectRef to the checkpointed IDs array. """ loader = IdColumnCheckpointLoader( checkpoint_path=self.checkpoint_path, @@ -191,76 +214,79 @@ def load_checkpoint(self) -> ObjectRef[Block]: def delete_checkpoint(self) -> None: self.filesystem.delete_dir(self.checkpoint_path_unwrapped) + def _warn_on_insufficient_memory(self): + """When using checkpoints, each read process needs to maintain checkpointed_ids + to filter the input blocks. This will increase the memory usage of each process. + When there is a potential risk of OOM, this function warns the user. + + """ + process = psutil.Process(os.getpid()) + process_memory_usage = process.memory_info().rss + node_memory = psutil.virtual_memory() + # This node cannot accept more read tasks + if process_memory_usage > node_memory.available: + logger.warning( + "Memory usage of current node: %.1f%%, per read task costs at least %d bytes." 
+ ' To prevent oom, set ray_remote_args={"memory": ${memory}} in ray.data.read_datasource()' + " and make sure ${memory} > %d", + node_memory.percent, + process_memory_usage, + process_memory_usage, + ) + def filter_rows_for_block( self, block: Block, - checkpointed_ids: Block, + checkpointed_ids: numpy.ndarray, ) -> Block: """For the given block, filter out rows that have already been checkpointed, and return the resulting block. Args: block: The input block to filter. - checkpointed_ids: A block containing IDs of all rows that have + checkpointed_ids: A numpy ndarray containing IDs of all rows that have been checkpointed. Returns: A new block with rows that have not been checkpointed. """ - if len(checkpointed_ids) == 0 or len(block) == 0: + if checkpointed_ids.shape[0] == 0 or len(block) == 0: return block assert isinstance(block, pyarrow.Table) - assert isinstance(checkpointed_ids, pyarrow.Table) + assert isinstance(checkpointed_ids, numpy.ndarray) + self._warn_on_insufficient_memory() # The checkpointed_ids block is sorted (see load_checkpoint). # We'll use binary search to filter out processed rows. - # And we process a single chunk at a time, otherwise `to_numpy` below - # will copy the data from shared memory to worker's heap memory. - - import concurrent.futures - # Get all chunks of the checkpointed ID column. - ckpt_chunks = checkpointed_ids[self.id_column].chunks # Convert the block's ID column to a numpy array for fast processing. block_ids = block[self.id_column].to_numpy() - def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.ChunkedArray) -> numpy.ndarray: - # Convert checkpoint chunk to numpy for fast search. - # Use internal helper function for consistency and robustness (handles null-typed arrays, etc.) - ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False) + def filter_with_ckpt() -> numpy.ndarray: # Start with a mask of all True (keep all rows). 
mask = numpy.ones(len(block_ids), dtype=bool) # Use binary search to find where block_ids would be in ckpt_ids. - sorted_indices = numpy.searchsorted(ckpt_ids, block_ids) + sorted_indices = numpy.searchsorted(checkpointed_ids, block_ids) # Only consider indices that are within bounds. - valid_indices = sorted_indices < len(ckpt_ids) + valid_indices = sorted_indices < len(checkpointed_ids) # For valid indices, check for exact matches. potential_matches = sorted_indices[valid_indices] - matched = ckpt_ids[potential_matches] == block_ids[valid_indices] + matched = checkpointed_ids[potential_matches] == block_ids[valid_indices] # Mark matched IDs as False (filter out these rows). mask[valid_indices] = ~matched - # Delete the chunk to free memory. - del ckpt_chunk return mask - # Use ThreadPoolExecutor to process each checkpoint chunk in parallel. - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.filter_num_threads or None - ) as executor: - masks = list(executor.map(filter_with_ckpt_chunk, ckpt_chunks)) - - # Combine all masks using logical AND (row must not be in any checkpoint chunk). - final_mask = numpy.logical_and.reduce(masks) + mask = filter_with_ckpt() # Convert the final mask to a PyArrow array and filter the block. - mask_array = pyarrow.array(final_mask) + mask_array = pyarrow.array(mask) filtered_block = block.filter(mask_array) return filtered_block def filter_rows_for_batch( self, batch: DataBatch, - checkpointed_ids: Block, + checkpointed_ids: numpy.ndarray, ) -> DataBatch: """For the given batch, filter out rows that have already been checkpointed, and return the resulting batch. 
diff --git a/python/ray/data/checkpoint/load_checkpoint_callback.py b/python/ray/data/checkpoint/load_checkpoint_callback.py index c637d6ca9998..bda89ff6b7b5 100644 --- a/python/ray/data/checkpoint/load_checkpoint_callback.py +++ b/python/ray/data/checkpoint/load_checkpoint_callback.py @@ -1,12 +1,13 @@ import logging from typing import Optional +import numpy + from ray.data._internal.execution.execution_callback import ( ExecutionCallback, remove_execution_callback, ) from ray.data._internal.execution.streaming_executor import StreamingExecutor -from ray.data.block import Block from ray.data.checkpoint import CheckpointConfig from ray.data.checkpoint.checkpoint_filter import BatchBasedCheckpointFilter from ray.types import ObjectRef @@ -22,7 +23,7 @@ def __init__(self, config: CheckpointConfig): self._config = config self._ckpt_filter = self._create_checkpoint_filter(config) - self._checkpoint_ref: Optional[ObjectRef[Block]] = None + self._checkpoint_ref: Optional[ObjectRef[numpy.ndarray]] = None def _create_checkpoint_filter( self, config: CheckpointConfig @@ -57,6 +58,6 @@ def after_execution_fails(self, executor: StreamingExecutor, error: Exception): # Remove the callback from the DataContext. 
remove_execution_callback(self, executor._data_context) - def load_checkpoint(self) -> ObjectRef[Block]: + def load_checkpoint(self) -> ObjectRef[numpy.ndarray]: assert self._checkpoint_ref is not None return self._checkpoint_ref diff --git a/python/ray/data/tests/test_checkpoint.py b/python/ray/data/tests/test_checkpoint.py index 699e06d33b16..ac16cbd338b7 100644 --- a/python/ray/data/tests/test_checkpoint.py +++ b/python/ray/data/tests/test_checkpoint.py @@ -12,6 +12,7 @@ from pytest_lazy_fixtures import lf as lazy_fixture import ray +from ray.data._internal.arrow_ops import transform_pyarrow from ray.data._internal.datasource.csv_datasource import CSVDatasource from ray.data._internal.datasource.parquet_datasink import ParquetDatasink from ray.data._internal.logical.interfaces.logical_plan import LogicalPlan @@ -682,8 +683,16 @@ def test_filter_rows_for_block(): chunk1 = pyarrow.table({ID_COL: [1, 2, 4]}) chunk2 = pyarrow.table({ID_COL: [6, 8, 9, 11]}) chunk3 = pyarrow.table({ID_COL: [12, 13]}) - checkpointed_ids = pyarrow.concat_tables([chunk1, chunk2, chunk3]) - assert len(checkpointed_ids[ID_COL].chunks) == 3 + pyarrow_checkpointed_ids = pyarrow.concat_tables([chunk1, chunk2, chunk3]) + assert len(pyarrow_checkpointed_ids[ID_COL].chunks) == 3 + + combined_ckpt_block = transform_pyarrow.combine_chunks(pyarrow_checkpointed_ids) + + combine_ckpt_chunks = combined_ckpt_block[ID_COL].chunks + assert len(combine_ckpt_chunks) == 1 + checkpoint_ids = transform_pyarrow.to_numpy( + combine_ckpt_chunks[0], zero_copy_only=False + ) expected_block = pyarrow.table( { @@ -695,7 +704,7 @@ def test_filter_rows_for_block(): filter_instance = BatchBasedCheckpointFilter(config) filtered_block = filter_instance.filter_rows_for_block( block=block, - checkpointed_ids=checkpointed_ids, + checkpointed_ids=checkpoint_ids, ) assert filtered_block.equals(expected_block)