114 changes: 70 additions & 44 deletions python/ray/data/checkpoint/checkpoint_filter.py
@@ -1,5 +1,6 @@
import abc
import logging
import os
import time
from typing import List, Optional

@@ -15,6 +16,8 @@
from ray.data.datasource.path_util import _unwrap_protocol
from ray.types import ObjectRef

import psutil

logger = logging.getLogger(__name__)


@@ -35,26 +38,43 @@ def __init__(self, config: CheckpointConfig):


@ray.remote(max_retries=-1)
def _combine_chunks(ckpt_block: pyarrow.Table) -> pyarrow.Table:
def _combine_chunks(ckpt_block: pyarrow.Table, id_column: str) -> numpy.ndarray:
"""Combine chunks for the checkpoint block.

Args:
ckpt_block: The checkpoint block to combine chunks for
id_column: The name of the ID column whose values are returned

Returns:
The combined checkpoint block
A numpy.ndarray of the IDs from the combined checkpoint block
"""
from ray.data._internal.arrow_ops.transform_pyarrow import combine_chunks

combined_ckpt_block = combine_chunks(ckpt_block)
# Combine chunks of ckpt_block
combined_ckpt_block = transform_pyarrow.combine_chunks(ckpt_block)
logger.debug(
"Checkpoint block stats for id column checkpoint: Combined block: type=%s, %d rows, %d bytes",
combined_ckpt_block.schema.to_string(),
combined_ckpt_block.num_rows,
combined_ckpt_block.nbytes,
)
if combined_ckpt_block.num_rows == 0:
return numpy.array([])

# In some cases (e.g., when the size of combined_ckpt_block[id_column] exceeds 2GB), there may still be multiple chunks.
combine_ckpt_chunks = combined_ckpt_block[id_column].chunks

logger.debug(
"Checkpoint stats for id column chunks: Num of chunks: %d",
len(combine_ckpt_chunks),
)

return combined_ckpt_block
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness (handles null-typed arrays, etc.)
ckpt_arrays = []
for chunk in combine_ckpt_chunks:
ckpt_arrays.append(transform_pyarrow.to_numpy(chunk, zero_copy_only=False))
final_ckpt_array = numpy.concatenate(ckpt_arrays)
return final_ckpt_array

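As the comment in `_combine_chunks` notes, an Arrow column larger than 2GB can remain split across chunks even after combining. A minimal standalone sketch of the same combine-then-convert step using only the public pyarrow API (the PR itself goes through Ray's internal `transform_pyarrow` helpers; the values here are illustrative):

```python
import numpy as np
import pyarrow as pa

# A ChunkedArray with several chunks, as produced when a column is
# assembled from many source files.
ids = pa.chunked_array([pa.array([3, 1]), pa.array([2]), pa.array([5, 4])])
assert ids.num_chunks == 3

# combine_chunks() concatenates everything into a single chunk.
combined = ids.combine_chunks()
assert combined.num_chunks == 1

# zero_copy_only=False permits a copy when the Arrow buffer layout
# (e.g., validity bitmaps) prevents a zero-copy numpy view.
arr = combined.chunk(0).to_numpy(zero_copy_only=False)
print(np.sort(arr))  # [1 2 3 4 5]
```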

class CheckpointLoader:
@@ -81,11 +101,11 @@ def __init__(
self.id_column = id_column
self.checkpoint_path_partition_filter = checkpoint_path_partition_filter

def load_checkpoint(self) -> ObjectRef[Block]:
def load_checkpoint(self) -> ObjectRef[numpy.ndarray]:
"""Loading checkpoint data.

Returns:
ObjectRef[Block]: ObjectRef to the checkpointed IDs block.
ObjectRef[numpy.ndarray]: ObjectRef to the checkpointed IDs array.
"""
start_t = time.time()

@@ -118,20 +138,21 @@ def load_checkpoint(self) -> ObjectRef[Block]:
metadata: BlockMetadata = ref_bundle.blocks[0][1]

# Post-process the block
checkpoint_block_ref: ObjectRef[Block] = self._postprocess_block(block_ref)
checkpoint_ndarray_ref: ObjectRef[numpy.ndarray] = self._postprocess_block(
block_ref
)

# Validate the loaded checkpoint
self._validate_loaded_checkpoint(schema, metadata)

logger.info(
"Checkpoint loaded for %s in %.2f seconds. SizeBytes = %d, Schema = %s",
"Checkpoint loaded for %s in %.2f seconds. Arrow SizeBytes = %d, Schema = %s",
type(self).__name__,
time.time() - start_t,
metadata.size_bytes,
schema.to_string(),
)

return checkpoint_block_ref
return checkpoint_ndarray_ref

@abc.abstractmethod
def _preprocess_data_pipeline(
@@ -140,9 +161,11 @@ def _preprocess_data_pipeline(
"""Pre-process the checkpoint dataset. To be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement this method")

def _postprocess_block(self, block_ref: ObjectRef[Block]) -> ObjectRef[Block]:
def _postprocess_block(
self, block_ref: ObjectRef[Block]
) -> ObjectRef[numpy.ndarray]:
"""Combine the block so it has fewer chunks."""
return _combine_chunks.remote(block_ref)
return _combine_chunks.remote(block_ref, self.id_column)

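Note that `_combine_chunks.remote(block_ref, self.id_column)` passes an ObjectRef straight into the task; Ray resolves it to the underlying block before the task body runs, so nothing is fetched to the caller. A minimal sketch of that chaining pattern (the task name and data are illustrative, not from the PR):

```python
import numpy as np
import pyarrow as pa
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def combine(tbl: pa.Table, id_col: str) -> np.ndarray:
    # Ray dereferences the ObjectRef argument before the task runs,
    # so `tbl` arrives here as a plain pyarrow.Table.
    return np.sort(tbl[id_col].to_numpy())

block_ref = ray.put(pa.table({"id": [4, 1, 3]}))
ids_ref = combine.remote(block_ref, "id")  # pass the ref; no ray.get needed
print(ray.get(ids_ref))  # [1 3 4]
```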
def _validate_loaded_checkpoint(
self, schema: Schema, metadata: BlockMetadata
@@ -174,11 +197,11 @@ def _preprocess_data_pipeline(
class BatchBasedCheckpointFilter(CheckpointFilter):
"""CheckpointFilter for batch-based backends."""

def load_checkpoint(self) -> ObjectRef[Block]:
def load_checkpoint(self) -> ObjectRef[numpy.ndarray]:
"""Load checkpointed ids as a sorted block.

Returns:
ObjectRef[Block]: ObjectRef to the checkpointed IDs block.
ObjectRef[numpy.ndarray]: ObjectRef to the checkpointed IDs array.
"""
loader = IdColumnCheckpointLoader(
checkpoint_path=self.checkpoint_path,
@@ -191,76 +214,79 @@ def load_checkpoint(self) -> ObjectRef[Block]:
def delete_checkpoint(self) -> None:
self.filesystem.delete_dir(self.checkpoint_path_unwrapped)

def _warn_on_insufficient_memory(self):
"""When using checkpoints, each read process needs to maintain checkpointed_ids
to filter the input blocks. This will increase the memory usage of each process.
When there is a potential risk of OOM, this function warn the user.

"""
process = psutil.Process(os.getpid())
process_memory_usage = process.memory_info().rss
node_memory = psutil.virtual_memory()
# This node cannot accept more read tasks
if process_memory_usage > node_memory.available:
logger.warning(
"Memory usage of current node: %.1f%%, per read task costs at least %d bytes."
' To prevent oom, set ray_remote_args={"memory": ${memory}} in ray.data.read_datasource()'
" and make sure ${memory} > %d",
node_memory.percent,
process_memory_usage,
process_memory_usage,
)

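The heuristic above compares the current process's resident set size against the memory still available on the node. A standalone sketch of the same psutil calls (the message text is illustrative):

```python
import os
import psutil

# Resident set size of this process vs. memory still available on the node.
rss = psutil.Process(os.getpid()).memory_info().rss
vm = psutil.virtual_memory()

if rss > vm.available:
    # Another task of this size would not fit on the node.
    print(
        f"node at {vm.percent:.1f}% memory; this process holds {rss} bytes, "
        f"more than the {vm.available} bytes still available"
    )
```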
def filter_rows_for_block(
self,
block: Block,
checkpointed_ids: Block,
checkpointed_ids: numpy.ndarray,
) -> Block:
"""For the given block, filter out rows that have already
been checkpointed, and return the resulting block.

Args:
block: The input block to filter.
checkpointed_ids: A block containing IDs of all rows that have
checkpointed_ids: A numpy ndarray containing IDs of all rows that have
been checkpointed.
Returns:
A new block with rows that have not been checkpointed.
"""

if len(checkpointed_ids) == 0 or len(block) == 0:
if checkpointed_ids.shape[0] == 0 or len(block) == 0:
return block

assert isinstance(block, pyarrow.Table)
assert isinstance(checkpointed_ids, pyarrow.Table)
assert isinstance(checkpointed_ids, numpy.ndarray)
self._warn_on_insufficient_memory()

# The checkpointed_ids block is sorted (see load_checkpoint).
# We'll use binary search to filter out processed rows.
# And we process a single chunk at a time, otherwise `to_numpy` below
# will copy the data from shared memory to worker's heap memory.

import concurrent.futures

# Get all chunks of the checkpointed ID column.
ckpt_chunks = checkpointed_ids[self.id_column].chunks
# Convert the block's ID column to a numpy array for fast processing.
block_ids = block[self.id_column].to_numpy()

def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.ChunkedArray) -> numpy.ndarray:
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness (handles null-typed arrays, etc.)
ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)
def filter_with_ckpt() -> numpy.ndarray:
# Start with a mask of all True (keep all rows).
mask = numpy.ones(len(block_ids), dtype=bool)
# Use binary search to find where block_ids would be in ckpt_ids.
sorted_indices = numpy.searchsorted(ckpt_ids, block_ids)
sorted_indices = numpy.searchsorted(checkpointed_ids, block_ids)
# Only consider indices that are within bounds.
valid_indices = sorted_indices < len(ckpt_ids)
valid_indices = sorted_indices < len(checkpointed_ids)
# For valid indices, check for exact matches.
potential_matches = sorted_indices[valid_indices]
matched = ckpt_ids[potential_matches] == block_ids[valid_indices]
matched = checkpointed_ids[potential_matches] == block_ids[valid_indices]
# Mark matched IDs as False (filter out these rows).
mask[valid_indices] = ~matched
# Delete the chunk to free memory.
del ckpt_chunk
return mask

# Use ThreadPoolExecutor to process each checkpoint chunk in parallel.
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.filter_num_threads or None
) as executor:
masks = list(executor.map(filter_with_ckpt_chunk, ckpt_chunks))

# Combine all masks using logical AND (row must not be in any checkpoint chunk).
final_mask = numpy.logical_and.reduce(masks)
mask = filter_with_ckpt()
# Convert the final mask to a PyArrow array and filter the block.
mask_array = pyarrow.array(final_mask)
mask_array = pyarrow.array(mask)
filtered_block = block.filter(mask_array)
return filtered_block

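The filter reduces to a vectorized membership test against a sorted array. A self-contained sketch of the searchsorted approach with toy data (mirroring the shape of `filter_rows_for_block`, not taken from the PR):

```python
import numpy as np
import pyarrow as pa

# Sorted checkpointed IDs and a block still to be processed.
ckpt = np.array([1, 2, 4, 6, 8, 9, 11, 12, 13])
block = pa.table({"id": [3, 4, 10, 13, 14]})
block_ids = block["id"].to_numpy()

# searchsorted gives each block ID's insertion point in the sorted
# checkpoint array; an exact match at that position means the row
# was already checkpointed.
idx = np.searchsorted(ckpt, block_ids)
in_bounds = idx < len(ckpt)
mask = np.ones(len(block_ids), dtype=bool)  # True = keep the row
mask[in_bounds] = ckpt[idx[in_bounds]] != block_ids[in_bounds]

filtered = block.filter(pa.array(mask))
print(filtered["id"].to_pylist())  # [3, 10, 14]
```

Binary search keeps the per-block cost at O(n log m) for n block rows against m checkpointed IDs, instead of O(n·m) for a naive scan.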
def filter_rows_for_batch(
self,
batch: DataBatch,
checkpointed_ids: Block,
checkpointed_ids: numpy.ndarray,
) -> DataBatch:
"""For the given batch, filter out rows that have already
been checkpointed, and return the resulting batch.
7 changes: 4 additions & 3 deletions python/ray/data/checkpoint/load_checkpoint_callback.py
@@ -1,12 +1,13 @@
import logging
from typing import Optional

import numpy

from ray.data._internal.execution.execution_callback import (
ExecutionCallback,
remove_execution_callback,
)
from ray.data._internal.execution.streaming_executor import StreamingExecutor
from ray.data.block import Block
from ray.data.checkpoint import CheckpointConfig
from ray.data.checkpoint.checkpoint_filter import BatchBasedCheckpointFilter
from ray.types import ObjectRef
@@ -22,7 +23,7 @@ def __init__(self, config: CheckpointConfig):
self._config = config

self._ckpt_filter = self._create_checkpoint_filter(config)
self._checkpoint_ref: Optional[ObjectRef[Block]] = None
self._checkpoint_ref: Optional[ObjectRef[numpy.ndarray]] = None

def _create_checkpoint_filter(
self, config: CheckpointConfig
@@ -57,6 +58,6 @@ def after_execution_fails(self, executor: StreamingExecutor, error: Exception):
# Remove the callback from the DataContext.
remove_execution_callback(self, executor._data_context)

def load_checkpoint(self) -> ObjectRef[Block]:
def load_checkpoint(self) -> ObjectRef[numpy.ndarray]:
assert self._checkpoint_ref is not None
return self._checkpoint_ref
15 changes: 12 additions & 3 deletions python/ray/data/tests/test_checkpoint.py
@@ -12,6 +12,7 @@
from pytest_lazy_fixtures import lf as lazy_fixture

import ray
from ray.data._internal.arrow_ops import transform_pyarrow
from ray.data._internal.datasource.csv_datasource import CSVDatasource
from ray.data._internal.datasource.parquet_datasink import ParquetDatasink
from ray.data._internal.logical.interfaces.logical_plan import LogicalPlan
@@ -682,8 +683,16 @@ def test_filter_rows_for_block():
chunk1 = pyarrow.table({ID_COL: [1, 2, 4]})
chunk2 = pyarrow.table({ID_COL: [6, 8, 9, 11]})
chunk3 = pyarrow.table({ID_COL: [12, 13]})
checkpointed_ids = pyarrow.concat_tables([chunk1, chunk2, chunk3])
assert len(checkpointed_ids[ID_COL].chunks) == 3
pyarrow_checkpointed_ids = pyarrow.concat_tables([chunk1, chunk2, chunk3])
assert len(pyarrow_checkpointed_ids[ID_COL].chunks) == 3

combined_ckpt_block = transform_pyarrow.combine_chunks(pyarrow_checkpointed_ids)

combine_ckpt_chunks = combined_ckpt_block[ID_COL].chunks
assert len(combine_ckpt_chunks) == 1
checkpoint_ids = transform_pyarrow.to_numpy(
combine_ckpt_chunks[0], zero_copy_only=False
)

expected_block = pyarrow.table(
{
@@ -695,7 +704,7 @@
filter_instance = BatchBasedCheckpointFilter(config)
filtered_block = filter_instance.filter_rows_for_block(
block=block,
checkpointed_ids=checkpointed_ids,
checkpointed_ids=checkpoint_ids,
)

assert filtered_block.equals(expected_block)