
Conversation

@wxwmd (Contributor) commented Jan 9, 2026

Modification 1

I'm using Ray Data's checkpointing. My dataset has 115 million records, with primary key {"id": str}. When I use the checkpoint to filter the input blocks, filtering takes several hours.

I checked the performance bottleneck and found it is in the filter_with_ckpt_chunk function in checkpoint_filter.py. I added some logs:

# Get all chunks of the checkpointed ID column.
ckpt_chunks = checkpointed_ids[self.id_column].chunks
# Convert the block's ID column to a numpy array for fast processing.
block_ids = block[self.id_column].to_numpy()

def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.Array) -> numpy.ndarray:
    t1 = time.time()
    ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False)
    print(f"ckpt_ids to numpy cost time: {time.time() - t1}s")

    ...
    t2 = time.time()
    sorted_indices = numpy.searchsorted(ckpt_ids, block_ids)
    print(f"searchsorted costs {time.time() - t2}s")

The ckpt_chunk has shape (115022113,) and block_ids has shape (14534,). I got:

ckpt_ids to numpy cost time: 6.057122468948364s
searchsorted costs 0.11587834358215332s

We can see from the perf test that:

  1. ckpt_chunks has only one chunk, because the chunks have already been combined via _combine_chunks.
  2. That single chunk holds all 115 million IDs, and converting it from PyArrow to NumPy takes about 6 s.
  3. ckpt_ids = transform_pyarrow.to_numpy(ckpt_chunk, zero_copy_only=False) is executed once for every input block, so the 6 s conversion is paid over and over, causing a large cumulative overhead.

This PR obtains the ckpt_ids NumPy array once, in advance, avoiding the repeated per-block conversion. In my tests, this reduces the filtering time from 5 hours to 40 minutes.
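The core idea, as a minimal standalone sketch (hypothetical function names; the real change lives in checkpoint_filter.py, and searchsorted requires the checkpoint IDs to be sorted):

import numpy as np
import pyarrow as pa

def build_ckpt_ids(checkpointed_ids: pa.ChunkedArray) -> np.ndarray:
    # Pay the Arrow -> NumPy conversion cost exactly once, up front.
    # For string data this materializes a dtype=object array.
    return checkpointed_ids.to_numpy()

def filter_block(block_ids: np.ndarray, ckpt_ids: np.ndarray) -> np.ndarray:
    # ckpt_ids must be sorted for the binary search to be valid.
    mask = np.ones(len(block_ids), dtype=bool)
    sorted_indices = np.searchsorted(ckpt_ids, block_ids)
    valid = sorted_indices < len(ckpt_ids)
    matched = ckpt_ids[sorted_indices[valid]] == block_ids[valid]
    mask[valid] = ~matched  # drop rows that already appear in the checkpoint
    return mask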

Notes:

In this PR, each read task reads ckpt_ids as a numpy.ndarray from the object store rather than in Arrow format. This increases I/O and memory overhead, because Arrow arrays usually take less space: in my experiment, the PyArrow array (115 million rows, string-typed) used 1.7 GB of memory, while the NumPy array used 9 GB. However, I think this memory overhead is acceptable given the performance improvement.
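To get a rough local feel for that footprint gap (illustrative only; the exact numbers depend on ID lengths and Python string overhead):

import sys
import pyarrow as pa

ids = pa.array([f"id-{i:09d}" for i in range(1_000_000)], type=pa.string())
print(f"Arrow buffers: {ids.nbytes / 1e6:.1f} MB")

as_np = ids.to_numpy(zero_copy_only=False)  # dtype=object: one Python str per row
approx = as_np.nbytes + sum(sys.getsizeof(s) for s in as_np)
print(f"NumPy object array (approx.): {approx / 1e6:.1f} MB")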

Modification 2

When a RayJob runs for the first time, there is no checkpoint yet, so there is no need to filter the input blocks. This PR adds a shortcut: when the checkpoint does not exist, the input block is returned directly without filtering.
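In isolation, the shortcut amounts to something like this sketch (hypothetical method and attribute names; the real change lives in Ray Data's checkpoint internals):

def filter_block(self, block):
    # First run: no checkpoint on disk, so there is nothing to filter against.
    if not self.is_checkpoint_existed:
        return block
    ckpt_ids = self._get_ckpt_ids()  # hypothetical accessor for the cached NumPy IDs
    ...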

Signed-off-by: xiaowen.wxw <[email protected]>
@wxwmd wxwmd requested a review from a team as a code owner January 9, 2026 09:15

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization for checkpoint filtering by converting checkpointed IDs to a NumPy array once, rather than for every block. The changes are well-implemented and consistent across the modified files. My review includes a couple of suggestions to enhance code clarity and maintainability.

Comment on lines +230 to +245
+def filter_with_ckpt() -> np.ndarray:
     # Start with a mask of all True (keep all rows).
-    mask = numpy.ones(len(block_ids), dtype=bool)
+    mask = np.ones(len(block_ids), dtype=bool)
     # Use binary search to find where block_ids would be in ckpt_ids.
-    sorted_indices = numpy.searchsorted(ckpt_ids, block_ids)
+    sorted_indices = np.searchsorted(checkpointed_ids, block_ids)
     # Only consider indices that are within bounds.
-    valid_indices = sorted_indices < len(ckpt_ids)
+    valid_indices = sorted_indices < len(checkpointed_ids)
     # For valid indices, check for exact matches.
     potential_matches = sorted_indices[valid_indices]
-    matched = ckpt_ids[potential_matches] == block_ids[valid_indices]
+    matched = checkpointed_ids[potential_matches] == block_ids[valid_indices]
     # Mark matched IDs as False (filter out these rows).
     mask[valid_indices] = ~matched
-    # Delete the chunk to free memory.
-    del ckpt_chunk
     return mask

-# Use ThreadPoolExecutor to process each checkpoint chunk in parallel.
-with concurrent.futures.ThreadPoolExecutor(
-    max_workers=self.filter_num_threads or None
-) as executor:
-    masks = list(executor.map(filter_with_ckpt_chunk, ckpt_chunks))
-
-# Combine all masks using logical AND (row must not be in any checkpoint chunk).
-final_mask = numpy.logical_and.reduce(masks)
+mask = filter_with_ckpt()

medium

The nested function filter_with_ckpt is only called once, so it can be inlined for better readability. Also, the comment on line 242 is a leftover from the previous implementation and should be removed.

        # Start with a mask of all True (keep all rows).
        mask = np.ones(len(block_ids), dtype=bool)
        # Use binary search to find where block_ids would be in ckpt_ids.
        sorted_indices = np.searchsorted(checkpointed_ids, block_ids)
        # Only consider indices that are within bounds.
        valid_indices = sorted_indices < len(checkpointed_ids)
        # For valid indices, check for exact matches.
        potential_matches = sorted_indices[valid_indices]
        matched = checkpointed_ids[potential_matches] == block_ids[valid_indices]
        # Mark matched IDs as False (filter out these rows).
        mask[valid_indices] = ~matched

Comment on lines +689 to +695
combined_ckpt_block = transform_pyarrow.combine_chunks(pyarrow_checkpointed_ids)

combine_ckpt_chunks = combined_ckpt_block[ID_COL].chunks
assert len(combine_ckpt_chunks) == 1
# Convert checkpoint chunk to numpy for fast search.
# Use internal helper function for consistency and robustness (handles null-typed arrays, etc.)
ckpt_ids = transform_pyarrow.to_numpy(combine_ckpt_chunks[0], zero_copy_only=False)

medium

This logic for converting a pyarrow Table to a numpy array of IDs is duplicated from _combine_chunks in checkpoint_filter.py. To improve maintainability, consider extracting this logic into a non-remote helper function in checkpoint_filter.py and calling it from both _combine_chunks and this test. This would avoid having to update the logic in two places if it ever changes.
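One possible shape for the shared helper (hypothetical name; it mirrors the duplicated logic, including the single-chunk assumption that a later comment questions):

def checkpoint_ids_to_numpy(table: pyarrow.Table, id_column: str) -> numpy.ndarray:
    """Combine the checkpoint table's chunks and return the ID column as NumPy."""
    combined = transform_pyarrow.combine_chunks(table)
    chunks = combined[id_column].chunks
    assert len(chunks) == 1  # see the 2 GiB caveat raised below
    return transform_pyarrow.to_numpy(chunks[0], zero_copy_only=False)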

@wxwmd wxwmd changed the title [Data] speedup ckpt filter 5x [Data] speedup checkpoint filter 5x Jan 9, 2026
Signed-off-by: xiaowen.wxw <[email protected]>

-    return combined_ckpt_block
+    combine_ckpt_chunks = combined_ckpt_block[id_column].chunks
+    assert len(combine_ckpt_chunks) == 1

Assertion fails for checkpoint data exceeding 2 GiB

High Severity

The assertion assert len(combine_ckpt_chunks) == 1 incorrectly assumes combine_chunks always produces a single chunk. The _try_combine_chunks_safe function in transform_pyarrow.py explicitly returns a ChunkedArray with multiple chunks when data exceeds 2 GiB for variable-width types like strings (to avoid int32 offset overflow). Since the PR targets string-typed IDs with 115M+ records, larger checkpoints exceeding 2 GiB will cause this assertion to fail at runtime, crashing the checkpoint loading process.
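A chunk-safe sketch of the fix: convert each chunk separately and concatenate the results rather than asserting a single chunk (reusing the transform_pyarrow.to_numpy helper from the reviewed code):

combine_ckpt_chunks = combined_ckpt_block[id_column].chunks
# combine_chunks may legitimately return multiple chunks for >2 GiB string
# data, so handle any number of chunks instead of asserting exactly one.
ckpt_ids = numpy.concatenate([
    transform_pyarrow.to_numpy(chunk, zero_copy_only=False)
    for chunk in combine_ckpt_chunks
])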


Signed-off-by: xiaowen.wxw <[email protected]>
# If checkpoint_path does not exist, set checkpoint_existed to False.
file_info = self.filesystem.get_file_info(self.checkpoint_path)
if file_info.type == pyarrow.fs.FileType.NotFound:
    self.checkpoint_existed = False

Typo in variable name causes AttributeError and wrong result

High Severity

The code uses self.checkpoint_existed on lines 114 and 127, but the attribute is defined as self.is_checkpoint_existed on line 90. When the checkpoint path exists, self.checkpoint_existed is never set, causing an AttributeError at line 127. When the path doesn't exist, line 114 sets the wrong attribute, so self.is_checkpoint_existed remains True and the function returns an incorrect result indicating the checkpoint exists when it doesn't.
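A minimal sketch of the fix: use the one attribute name consistently (assuming the surrounding method returns the flag):

self.is_checkpoint_existed = True
file_info = self.filesystem.get_file_info(self.checkpoint_path)
if file_info.type == pyarrow.fs.FileType.NotFound:
    self.is_checkpoint_existed = False
return self.is_checkpoint_existed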


@ray-gardener ray-gardener bot added data Ray Data-related issues community-contribution Contributed by the community labels Jan 9, 2026