11 changes: 1 addition & 10 deletions python/ray/data/_internal/planner/checkpoint/plan_read_op.py
@@ -1,7 +1,6 @@
import functools
from typing import Callable, List, Optional
from typing import List

from ray import ObjectRef
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
@@ -10,7 +9,6 @@
from ray.data._internal.output_buffer import OutputBlockSizeOption
from ray.data._internal.planner.plan_read_op import plan_read_op
from ray.data.checkpoint.util import (
CHECKPOINTED_IDS_KWARG_NAME,
filter_checkpointed_rows_for_blocks,
)
from ray.data.context import DataContext
@@ -20,7 +18,6 @@ def plan_read_op_with_checkpoint_filter(
op: Read,
physical_children: List[PhysicalOperator],
data_context: DataContext,
load_checkpoint: Optional[Callable[[], ObjectRef]] = None,
) -> PhysicalOperator:
physical_op = plan_read_op(op, physical_children, data_context)

@@ -30,7 +27,6 @@ def plan_read_op_with_checkpoint_filter(
BlockMapTransformFn(
functools.partial(
filter_checkpointed_rows_for_blocks,
checkpoint_config=data_context.checkpoint_config,
),
output_block_size_option=OutputBlockSizeOption.of(
target_max_block_size=data_context.target_max_block_size,
@@ -39,9 +35,4 @@
]
)

if load_checkpoint is not None:
physical_op.add_map_task_kwargs_fn(
lambda: {CHECKPOINTED_IDS_KWARG_NAME: load_checkpoint()}
)

return physical_op
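# --- Illustrative sketch, not from this PR ---
# The planner above appends the checkpoint filter to the read operator's
# transform chain by binding configuration with functools.partial. A minimal,
# self-contained sketch of that pattern; drop_seen_rows and the Block alias
# are hypothetical stand-ins for the Ray internals.
import functools
from typing import Iterable, Set

import pyarrow

Block = pyarrow.Table  # stand-in for ray.data.block.Block

def drop_seen_rows(blocks: Iterable[Block], *, id_column: str, seen: Set) -> Iterable[Block]:
    # Keep only rows whose ID has not already been processed.
    for block in blocks:
        mask = pyarrow.array([v not in seen for v in block[id_column].to_pylist()])
        yield block.filter(mask)

# Bind the configuration up front, leaving a callable that the planner can
# append to the operator's transform chain.
transform_fn = functools.partial(drop_seen_rows, id_column="id", seen={2, 3})

for filtered in transform_fn([pyarrow.table({"id": [1, 2, 3, 4]})]):
    print(filtered["id"].to_pylist())  # [1, 4]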
8 changes: 1 addition & 7 deletions python/ray/data/_internal/planner/planner.py
@@ -2,7 +2,6 @@
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar

from ray import ObjectRef
from ray.data._internal.execution.execution_callback import add_execution_callback
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.operators.aggregate_num_rows import (
@@ -185,13 +184,10 @@ def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan:

checkpoint_callback = self._create_checkpoint_callback(checkpoint_config)
add_execution_callback(checkpoint_callback, logical_plan.context)
load_checkpoint = checkpoint_callback.load_checkpoint

# Dynamically set the plan functions for checkpointing because they
# need a reference to the checkpoint ref.
self._plan_fns_for_checkpointing = self._get_plan_fns_for_checkpointing(
load_checkpoint
)
self._plan_fns_for_checkpointing = self._get_plan_fns_for_checkpointing()

elif checkpoint_config is not None:
assert not self._check_supports_checkpointing(logical_plan)
@@ -275,12 +271,10 @@ def _create_checkpoint_callback(self, checkpoint_config) -> LoadCheckpointCallba

def _get_plan_fns_for_checkpointing(
self,
load_checkpoint: Callable[[], ObjectRef],
) -> Dict[Type[LogicalOperator], PlanLogicalOpFn]:
plan_fns = {
Read: functools.partial(
plan_read_op_with_checkpoint_filter,
load_checkpoint=load_checkpoint,
),
Write: plan_write_op_with_checkpoint_writer,
}
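# --- Illustrative sketch, not from this PR ---
# _get_plan_fns_for_checkpointing above feeds a type-keyed dispatch table.
# The operator classes and plan functions below are hypothetical
# simplifications, not Ray's actual planner API.
from typing import Callable, Dict, Type

class LogicalOp: ...
class Read(LogicalOp): ...
class Write(LogicalOp): ...

def plan_read(op: Read) -> str:
    return "physical read with checkpoint filter"

def plan_write(op: Write) -> str:
    return "physical write with checkpoint writer"

# The planner looks up a plan function by the logical operator's class and
# falls back to default planning when no override is registered.
PLAN_FNS: Dict[Type[LogicalOp], Callable] = {Read: plan_read, Write: plan_write}

def plan(op: LogicalOp) -> str:
    plan_fn = PLAN_FNS.get(type(op))
    return plan_fn(op) if plan_fn is not None else "default physical op"

print(plan(Read()))   # physical read with checkpoint filter
print(plan(Write()))  # physical write with checkpoint writer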
149 changes: 64 additions & 85 deletions python/ray/data/checkpoint/checkpoint_filter.py
@@ -8,6 +8,7 @@

import ray
from ray.data._internal.arrow_ops import transform_pyarrow
from ray.data._internal.arrow_ops.transform_pyarrow import combine_chunks
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data.block import Block, BlockAccessor, BlockMetadata, DataBatch, Schema
from ray.data.checkpoint import CheckpointConfig
@@ -18,45 +19,6 @@
logger = logging.getLogger(__name__)


class CheckpointFilter(abc.ABC):
"""Abstract class which defines the interface for filtering checkpointed rows
based on varying backends.
"""

def __init__(self, config: CheckpointConfig):
self.ckpt_config = config
self.checkpoint_path = self.ckpt_config.checkpoint_path
self.checkpoint_path_unwrapped = _unwrap_protocol(
self.ckpt_config.checkpoint_path
)
self.id_column = self.ckpt_config.id_column
self.filesystem = self.ckpt_config.filesystem
self.filter_num_threads = self.ckpt_config.filter_num_threads


@ray.remote(max_retries=-1)
def _combine_chunks(ckpt_block: pyarrow.Table) -> pyarrow.Table:
"""Combine chunks for the checkpoint block.

Args:
ckpt_block: The checkpoint block to combine chunks for

Returns:
The combined checkpoint block
"""
from ray.data._internal.arrow_ops.transform_pyarrow import combine_chunks

combined_ckpt_block = combine_chunks(ckpt_block)
logger.debug(
"Checkpoint block stats for id column checkpoint: Combined block: type=%s, %d rows, %d bytes",
combined_ckpt_block.schema.to_string(),
combined_ckpt_block.num_rows,
combined_ckpt_block.nbytes,
)

return combined_ckpt_block


class CheckpointLoader:
"""Loading checkpoint data."""

@@ -81,11 +43,11 @@ def __init__(
self.id_column = id_column
self.checkpoint_path_partition_filter = checkpoint_path_partition_filter

def load_checkpoint(self) -> ObjectRef[Block]:
def load_checkpoint(self) -> numpy.ndarray:
"""Loading checkpoint data.

Returns:
ObjectRef[Block]: ObjectRef to the checkpointed IDs block.
numpy.ndarray: The checkpointed IDs array.
"""
start_t = time.time()

@@ -118,20 +80,19 @@ def load_checkpoint(self) -> ObjectRef[Block]:
metadata: BlockMetadata = ref_bundle.blocks[0][1]

# Post-process the block
checkpoint_block_ref: ObjectRef[Block] = self._postprocess_block(block_ref)
checkpoint_ndarray: numpy.ndarray = self._postprocess_block(block_ref)

# Validate the loaded checkpoint
self._validate_loaded_checkpoint(schema, metadata)

logger.info(
"Checkpoint loaded for %s in %.2f seconds. SizeBytes = %d, Schema = %s",
"Checkpoint loaded for %s in %.2f seconds. Arrow SizeBytes = %d, Schema = %s",
type(self).__name__,
time.time() - start_t,
metadata.size_bytes,
schema.to_string(),
)

return checkpoint_block_ref
return checkpoint_ndarray

@abc.abstractmethod
def _preprocess_data_pipeline(
@@ -140,9 +101,22 @@ def _preprocess_data_pipeline(
"""Pre-process the checkpoint dataset. To be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement this method")

def _postprocess_block(self, block_ref: ObjectRef[Block]) -> ObjectRef[Block]:
"""Combine the block so it has fewer chunks."""
return _combine_chunks.remote(block_ref)
def _postprocess_block(self, block_ref: ObjectRef[Block]) -> numpy.ndarray:
checkpointed_ids = ray.get(block_ref)
if checkpointed_ids.num_rows == 0:
return numpy.array([])

combined_checkpointed_ids = combine_chunks(checkpointed_ids)
ckpt_chunks = combined_checkpointed_ids[self.id_column].chunks

checkpoint_ids_array = []
for ckpt_chunk in ckpt_chunks:
checkpoint_ids_array.append(
transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)
)
result = numpy.concatenate(checkpoint_ids_array)

return result
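# --- Illustrative sketch, not from this PR ---
# _postprocess_block above flattens a (possibly multi-chunk) Arrow column
# into one contiguous NumPy array. The same idea using only public
# pyarrow/numpy APIs; zero_copy_only=False is required because string chunks
# cannot be viewed zero-copy.
import numpy
import pyarrow

chunked = pyarrow.chunked_array([["a", "b"], ["c"]])  # column with two chunks
parts = [chunk.to_numpy(zero_copy_only=False) for chunk in chunked.chunks]
ids = numpy.concatenate(parts)
print(ids)  # ['a' 'b' 'c']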

def _validate_loaded_checkpoint(
self, schema: Schema, metadata: BlockMetadata
@@ -171,96 +145,101 @@ def _preprocess_data_pipeline(
return checkpoint_ds.sort(self.id_column)


class CheckpointFilter(abc.ABC):
"""Abstract class which defines the interface for filtering checkpointed rows
based on varying backends.
"""

def __init__(self, config: CheckpointConfig):
self.ckpt_config = config
self.checkpoint_path = self.ckpt_config.checkpoint_path
self.checkpoint_path_unwrapped = _unwrap_protocol(
self.ckpt_config.checkpoint_path
)
self.id_column = self.ckpt_config.id_column
self.filesystem = self.ckpt_config.filesystem
self.filter_num_threads = self.ckpt_config.filter_num_threads
self.checkpointed_ids = None


@ray.remote
class BatchBasedCheckpointFilter(CheckpointFilter):
"""CheckpointFilter for batch-based backends."""
"""CheckpointFilter for batch-based backends.

def load_checkpoint(self) -> ObjectRef[Block]:
"""Load checkpointed ids as a sorted block.
This is a global actor that holds the checkpointed IDs array.
Every read task sends its input block to this actor and gets back the filtered result.
"""

Returns:
ObjectRef[Block]: ObjectRef to the checkpointed IDs block.
"""
def __init__(self, config: CheckpointConfig):
super().__init__(config)

# load checkpoint
loader = IdColumnCheckpointLoader(
checkpoint_path=self.checkpoint_path,
filesystem=self.filesystem,
id_column=self.id_column,
checkpoint_path_partition_filter=self.ckpt_config.checkpoint_path_partition_filter,
)
return loader.load_checkpoint()
self.checkpointed_ids = loader.load_checkpoint()

assert isinstance(self.checkpointed_ids, numpy.ndarray)

def ready(self):
return True

def delete_checkpoint(self) -> None:
self.filesystem.delete_dir(self.checkpoint_path_unwrapped)

def filter_rows_for_block(
self,
block: Block,
checkpointed_ids: Block,
) -> Block:
"""For the given block, filter out rows that have already
been checkpointed, and return the resulting block.

Args:
block: The input block to filter.
checkpointed_ids: A block containing IDs of all rows that have
been checkpointed.
Returns:
A new block with rows that have not been checkpointed.
"""

if len(checkpointed_ids) == 0 or len(block) == 0:
if self.checkpointed_ids.shape[0] == 0 or len(block) == 0:
return block

assert isinstance(block, pyarrow.Table)
assert isinstance(checkpointed_ids, pyarrow.Table)

# The checkpointed_ids block is sorted (see load_checkpoint).
# We'll use binary search to filter out processed rows.
# And we process a single chunk at a time, otherwise `to_numpy` below
# will copy the data from shared memory to worker's heap memory.

import concurrent.futures

# Get all chunks of the checkpointed ID column.
ckpt_chunks = checkpointed_ids[self.id_column].chunks
# Convert the block's ID column to a numpy array for fast processing.
block_ids = block[self.id_column].to_numpy()

def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.ChunkedArray) -> numpy.ndarray:
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness (handles null-typed arrays, etc.)
ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)
def filter_with_ckpt() -> numpy.ndarray:
# Start with a mask of all True (keep all rows).
mask = numpy.ones(len(block_ids), dtype=bool)
# Use binary search to find where block_ids would be in the checkpointed IDs.
sorted_indices = numpy.searchsorted(ckpt_ids, block_ids)
sorted_indices = numpy.searchsorted(self.checkpointed_ids, block_ids)
# Only consider indices that are within bounds.
valid_indices = sorted_indices < len(ckpt_ids)
valid_indices = sorted_indices < len(self.checkpointed_ids)
# For valid indices, check for exact matches.
potential_matches = sorted_indices[valid_indices]
matched = ckpt_ids[potential_matches] == block_ids[valid_indices]
matched = (
self.checkpointed_ids[potential_matches] == block_ids[valid_indices]
)
# Mark matched IDs as False (filter out these rows).
mask[valid_indices] = ~matched
# Delete the chunk to free memory.
del ckpt_chunk
return mask

# Use ThreadPoolExecutor to process each checkpoint chunk in parallel.
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.filter_num_threads or None
) as executor:
masks = list(executor.map(filter_with_ckpt_chunk, ckpt_chunks))
mask = filter_with_ckpt()

# Combine all masks using logical AND (row must not be in any checkpoint chunk).
final_mask = numpy.logical_and.reduce(masks)
# Convert the final mask to a PyArrow array and filter the block.
mask_array = pyarrow.array(final_mask)
mask_array = pyarrow.array(mask)
filtered_block = block.filter(mask_array)
return filtered_block
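# --- Illustrative sketch, not from this PR ---
# The mask construction above relies on numpy.searchsorted over the sorted
# checkpointed IDs. A standalone worked example of the same technique
# (data values are made up).
import numpy
import pyarrow

ckpt_ids = numpy.array([2, 5, 9])  # sorted checkpointed IDs
block = pyarrow.table({"id": [1, 2, 3, 9]})
block_ids = block["id"].to_numpy()

# Binary search gives, for each block ID, its would-be position in ckpt_ids.
idx = numpy.searchsorted(ckpt_ids, block_ids)
valid = idx < len(ckpt_ids)

# Exact-match check for the in-bounds candidates; matches are filtered out.
matched = numpy.zeros(len(block_ids), dtype=bool)
matched[valid] = ckpt_ids[idx[valid]] == block_ids[valid]

filtered = block.filter(pyarrow.array(~matched))
print(filtered["id"].to_pylist())  # [1, 3]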

def filter_rows_for_batch(
self,
batch: DataBatch,
checkpointed_ids: Block,
) -> DataBatch:
"""For the given batch, filter out rows that have already
been checkpointed, and return the resulting batch.
@@ -269,6 +248,6 @@ def filter_rows_for_batch(
so it is preferred to call that method directly if you already have a block.
"""
arrow_block = BlockAccessor.batch_to_block(batch)
filtered_block = self.filter_rows_for_block(arrow_block, checkpointed_ids)
filtered_block = self.filter_rows_for_block(arrow_block)
filtered_batch = BlockAccessor.for_block(filtered_block).to_batch_format(None)
return filtered_batch
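# --- Illustrative sketch, not from this PR ---
# The global-actor pattern the class docstring describes: one actor holds the
# already-processed IDs, and each task ships its IDs over and gets a keep-mask
# back. SeenIdsFilter is a hypothetical simplification; numpy.isin stands in
# for the sorted binary search used above.
import numpy
import ray

@ray.remote
class SeenIdsFilter:
    def __init__(self, seen_ids: numpy.ndarray):
        self.seen_ids = seen_ids

    def ready(self) -> bool:
        # ray.get on this call blocks until __init__ has finished loading.
        return True

    def keep_mask(self, ids: numpy.ndarray) -> numpy.ndarray:
        return ~numpy.isin(ids, self.seen_ids)

actor = SeenIdsFilter.remote(numpy.array([2, 5, 9]))
ray.get(actor.ready.remote())  # wait for the checkpoint load
print(ray.get(actor.keep_mask.remote(numpy.array([1, 2, 3, 9]))))
# [ True False  True False]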