
[Data] Speed up checkpoint filter and reduce memory usage#60294

Open
wxwmd wants to merge 1 commit into ray-project:master from wxwmd:global_ckpt_filter

Conversation

@wxwmd
Contributor

@wxwmd wxwmd commented Jan 19, 2026

Source code for issue #60200.

Current checkpoint:

[Image: diagram of the current checkpoint filtering flow]

The current implementation has two issues:

  1. Each ReadTask copies the Arrow-typed checkpoint_id array and then converts it into a NumPy-typed array. This conversion is very time-consuming (see previous testing), and this most expensive operation is repeated in every ReadTask.
  2. Each ReadTask holds its own copy of the checkpoint_id array, resulting in high memory usage across the cluster.

Improved Checkpoint:

Maintain a global checkpoint_filter actor that holds the checkpoint_ids array; this actor is responsible for filtering all input blocks.

[Image: diagram of the improved checkpoint filtering flow with a global checkpoint_filter actor]

There are two advantages to this approach:

  1. The most time-consuming operation, the conversion from the Arrow-typed array to a NumPy-typed array, is performed only once.
  2. Reduced memory usage: each read task no longer needs to hold the large array; only the checkpoint_filter actor holds it.
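A minimal sketch of the actor-based filtering, assuming the actor loads the checkpoint itself; apart from BatchBasedCheckpointFilter (which this PR introduces), the method and parameter names are illustrative rather than the exact PR code:

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import ray


@ray.remote
class BatchBasedCheckpointFilter:
    """Holds the checkpointed IDs once and filters blocks for every read task."""

    def __init__(self, ckpt_path: str, id_column: str):
        table = pq.read_table(ckpt_path, columns=[id_column])
        # The expensive Arrow -> NumPy conversion happens exactly once, here.
        ids = np.asarray(table.column(id_column).combine_chunks())
        self._ckpt_ids = np.sort(ids)  # sorted so searchsorted can be used below

    def filter_block(self, block: pa.Table, id_column: str) -> pa.Table:
        # Keep only the rows whose id is NOT already in the checkpoint
        # (assumes a non-empty checkpoint).
        block_ids = np.asarray(block.column(id_column))
        pos = np.clip(np.searchsorted(self._ckpt_ids, block_ids), 0, len(self._ckpt_ids) - 1)
        already_done = self._ckpt_ids[pos] == block_ids
        return block.filter(pa.array(~already_done))

Each ReadTask then sends its block to this actor, e.g. filtered = ray.get(ckpt_filter.filter_block.remote(block, "id")), instead of holding and converting the checkpoint IDs itself.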

Performance test

test code:

import shutil
from typing import Dict
import os
import time

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH="/tmp/ray_test/input/"
OUTPUT_PATH="/tmp/ray_test/output/"
CKPT_PATH="/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input

def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
    for i in range(10000):
        ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
        df = pd.DataFrame({'id': ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
    ids = [str(i) for i in range(0, 80_000_000)]
    df = pd.DataFrame({'id': ids})
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))



if __name__ == "__main__":
    ray.init()

    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()

    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        memory=8 * 1024 ** 3 # set for original Ray to avoid OOM
    )

    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)

    pred.write_parquet(OUTPUT_PATH)

    end_time = time.time()

    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000

Node: 16 cores with 64 GB memory (make sure you have at least 16 GB of free memory to avoid OOM).

Original Ray:

pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/origin/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py

Speedup:

pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py

Test Result

original: 680 s
speedup: 190 s
You can see that even the end-to-end running time of the whole job is about 3.6x faster.

Memory

If we delete this line:

memory=8 * 1024 ** 3 # set for original Ray to avoid OOM

the original Ray OOMs while the fixed Ray passes. This demonstrates that this PR also improves stability.

@wxwmd wxwmd requested a review from a team as a code owner January 19, 2026 11:54

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the checkpoint filtering mechanism to use a global actor that holds checkpointed IDs as a NumPy array. This is a significant improvement for memory efficiency and performance by avoiding passing large checkpoint data to each task. The changes to the build and CI scripts seem appropriate for the internal environment.

My review identified a critical issue where the batch-based filtering path (filter_rows_for_batch and its caller) was not updated to align with the new actor-based architecture, which will lead to runtime errors. I also found a medium-severity performance issue due to using ray.get() inside a loop and a minor issue with a leftover debug print statement.
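For context, the ray.get-in-a-loop pattern the review mentions is a common Ray anti-pattern; the snippet below is a generic illustration (not the PR's code) of how batching the calls avoids it:

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def square(x: int) -> int:
    return x * x


# Anti-pattern: ray.get() inside the loop makes the remote calls run one after another.
slow = [ray.get(square.remote(i)) for i in range(8)]

# Better: submit all tasks first, then fetch every result with a single ray.get().
refs = [square.remote(i) for i in range(8)]
fast = ray.get(refs)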


@ray-gardener ray-gardener bot added data Ray Data-related issues community-contribution Contributed by the community labels Jan 19, 2026
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from b971082 to b2fc954 Compare January 20, 2026 03:33

@daiping8
Contributor

daiping8 commented Jan 20, 2026

Great job! I have a few questions:

  • A single global actor may become a bottleneck.
    All Read-related filtering requests go through the same BatchBasedCheckpointFilter actor. Could this cause filtering requests to queue up and stall the entire read pipeline on checkpoint filtering?
  • checkpointed_ids is fully materialized as a numpy array and kept resident in memory.
    We call combine_chunks and then to_numpy once to build a single large ndarray.
    If the checkpoint is large, the actor process must have enough contiguous memory to hold the entire ID column.

Maybe we could turn the actor into a sharded design (multiple actors, partitioned by hash(id) or by ID ranges), and support a “partial loading + partial filtering” mode instead of materializing the entire ndarray at once.
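For illustration only (not part of this PR), a sharded design along those lines could route each id to a shard with a stable hash, so read tasks and shard actors always agree on the partitioning:

import zlib

import numpy as np

NUM_SHARDS = 8  # illustrative shard count


def shard_of(id_str: str) -> int:
    # Stable hash (unlike Python's built-in hash(), which is salted per process).
    return zlib.crc32(id_str.encode("utf-8")) % NUM_SHARDS


def split_ids_by_shard(ids: np.ndarray) -> dict:
    # Group a block's ids by shard; each group is then filtered by its own actor,
    # which only needs to hold its own partition of the checkpoint.
    shards = np.fromiter((shard_of(x) for x in ids), dtype=np.int64, count=len(ids))
    return {s: ids[shards == s] for s in range(NUM_SHARDS)}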

@wxwmd
Contributor Author

wxwmd commented Jan 21, 2026

Great job! I have a few questions:

Thanks.

  1. In my test, filtering requests are processed very quickly. With a checkpoint of 115 million rows and blocks of 10k+ rows each, a filtering request completes in about 0.2 s. See my log:
[Image: filtering latency log]

@wxwmd
Contributor Author

wxwmd commented Jan 21, 2026

Great job! I have a few questions:

  • A single global actor may become a bottleneck.
    All Read-related filtering requests go through the same BatchBasedCheckpointFilter actor. Could this cause filtering requests to queue up and stall the entire read pipeline on checkpoint filtering?
  • checkpointed_ids is fully materialized as a numpy array and kept resident in memory.
    We call combine_chunks and then to_numpy once to build a single large ndarray.
    If the checkpoint is large, the actor process must have enough contiguous memory to hold the entire ID column.

Maybe we could turn the actor into a sharded design (multiple actors, partitioned by hash(id) or by ID ranges), and support a “partial loading + partial filtering” mode instead of materializing the entire ndarray at once.

  2. I think the sharded design is a good idea; I am interested in implementing it in the future.

@wxwmd wxwmd force-pushed the global_ckpt_filter branch from b2fc954 to eb93d53 Compare January 23, 2026 08:43

@wxwmd
Contributor Author

wxwmd commented Jan 23, 2026

@owenowenisme @raulchen please check this when you have time 😊

@wxwmd wxwmd force-pushed the global_ckpt_filter branch from eb93d53 to 0276b22 Compare January 23, 2026 09:16

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
job_id = ray.get_runtime_context().get_job_id()
self._ckpt_filter = BatchBasedCheckpointFilter.options(
    name=f"checkpoint_filter_{job_id}",
    lifetime="detached",
Member

Do we have to make the actor detached?

Contributor Author

I will test the lifecycle of this actor.

job_id = ray.get_runtime_context().get_job_id()
self._ckpt_filter = BatchBasedCheckpointFilter.options(
    name=f"checkpoint_filter_{job_id}",
    lifetime="detached",
Member

Since you're using an actor now and the ThreadPoolExecutor is removed, can we use max_concurrency here?

Contributor Author

You mean using a thread pool to perform the filtering? I think that is a good idea; I will implement it.
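For reference, a rough sketch of that suggestion, reusing the illustrative BatchBasedCheckpointFilter from the description above (the constructor arguments are assumptions, not the PR's exact signature); max_concurrency is a standard Ray actor option that turns the actor into a threaded actor:

# Allow up to 8 concurrent filter_block calls on the single actor.
# The filtering path must be thread-safe for this to be correct.
ckpt_filter = BatchBasedCheckpointFilter.options(
    name=f"checkpoint_filter_{job_id}",
    max_concurrency=8,
).remote(ckpt_path, id_column)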

Member

@owenowenisme owenowenisme left a comment


Sorry for reviewing this so late, and thanks for the beautiful diagram!

When I was reviewing your global actor approach, I had another idea. The actor introduces a serial bottleneck — every read task has to ship its block to the single actor for filtering and wait for the result back. Without max_concurrency, calls are processed one at a time, which could be a significant throughput regression from the old design where each worker filtered locally in parallel.

Instead, what if we keep the filtering local in each worker but broadcast the checkpoint IDs as a numpy array via the object store?
The approach would be:

  1. Load checkpoint data and convert to a sorted numpy array (your PR already does this in _postprocess_block — nice work on that part!)
  2. Use a remote task to do the heavy conversion, then ray.put() the numpy array into the object store
  3. Pass the ObjectRef to each read task via add_map_task_kwargs_fn (the old mechanism)
  4. Each worker calls ray.get(ref) to get a zero-copy read-only view from the local object store, then does searchsorted locally

This gives us:

  • Parallelism: filtering is parallel across all workers, no bottleneck
  • Memory efficiency: Ray's object store stores one copy per node in shared memory, and all workers on the same node share it via zero-copy
  • Minimize the re-computation of converting arrow blocks into numpy
  • Simplicity: No actor needed
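A minimal sketch of this broadcast alternative, under the assumptions above (function names are illustrative, not existing Ray Data internals):

import numpy as np
import pyarrow.parquet as pq
import ray


@ray.remote
def load_sorted_ckpt_ids(ckpt_path: str, id_column: str) -> np.ndarray:
    # The heavy Arrow -> NumPy conversion runs once, inside a remote task;
    # the returned array lives in the object store and is shared via its ObjectRef.
    table = pq.read_table(ckpt_path, columns=[id_column])
    return np.sort(np.asarray(table.column(id_column).combine_chunks()))


# Driver: one ObjectRef that can be passed to every read task.
ckpt_ids_ref = load_sorted_ckpt_ids.remote("/tmp/ray_test/ckpt/", "id")


def keep_mask(block_ids: np.ndarray, ckpt_ids_ref) -> np.ndarray:
    # Inside each read task: fetch the shared array (zero-copy for numeric dtypes;
    # string IDs are deserialized per worker, the concern raised in the reply below)
    # and filter locally with searchsorted.
    ckpt_ids = ray.get(ckpt_ids_ref)
    pos = np.clip(np.searchsorted(ckpt_ids, block_ids), 0, len(ckpt_ids) - 1)
    return ckpt_ids[pos] != block_ids  # True for rows not yet checkpointed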

@wxwmd
Contributor Author

wxwmd commented Feb 10, 2026

Sorry for reviewing this so late, and thanks for the beautiful diagram!

When I was reviewing your global actor approach, I had another idea. The actor introduces a serial bottleneck — every read task has to ship its block to the single actor for filtering and wait for the result back. Without max_concurrency, calls are processed one at a time, which could be a significant throughput regression from the old design where each worker filtered locally in parallel.

Instead, what if we keep the filtering local in each worker but broadcast the checkpoint IDs as a numpy array via the object store? The approach would be:

  1. Load checkpoint data and convert to a sorted numpy array (your PR already does this in _postprocess_block — nice work on that part!)
  2. Use a remote task to do the heavy conversion, then ray.put() the numpy array into the object store
  3. Pass the ObjectRef to each read task via add_map_task_kwargs_fn (the old mechanism)
  4. Each worker calls ray.get(ref) to get a zero-copy read-only view from the local object store, then does searchsorted locally

This gives us:

  • Parallelism: filtering is parallel across all workers, no bottleneck
  • Memory efficiency: Ray's object store stores one copy per node in shared memory, and all workers on the same node share it via zero-copy
  • Minimize the re-computation of converting arrow blocks into numpy
  • Simplicity: No actor needed

Hi youcheng, thanks for reviewing!

When I first tried to solve this problem, I had the same idea as you: perform the Arrow->NumPy conversion only once, then broadcast the resulting NumPy array.

I implemented and tested this approach, and it has one issue: the NumPy array is too large.
For example, I have 100 million string IDs; storing them with Arrow takes 2 GB, but with NumPy it takes about 10 GB. See the demo below:

import sys

import numpy as np

N = 10_000_000  # 10 million; kept small to avoid OOM
arr = np.array([f"text_{i}" for i in range(N)])

mem = arr.nbytes + sum(sys.getsizeof(s) for s in arr)

print(f"the 1kw arr costs {mem / 1024**3}GB, array of size 10000_0000 will costs {10 * mem / 1024**3}GB")

Having each node keep a ~10 GB object in memory is unacceptable: our cluster has about 1,000 nodes, which means roughly 10,000 GB of memory would be used just for the checkpoint.
This is also the second issue my diagram aims to address: redundant memory usage.

@owenowenisme
Member

owenowenisme commented Feb 10, 2026

Got it, I think this is valid. One problem is that we should avoid this actor becoming a bottleneck.
Do you have any plan to avoid this?
Also, this could affect our backpressure, right?

@wxwmd
Contributor Author

wxwmd commented Feb 11, 2026

Got it, I think this is valid. One problem is that we should avoid this actor becoming a bottleneck. Do you have any plan to avoid this? Also, this could affect our backpressure, right?

Yes, I will use concurrency to enhance the actor in this PR.
As for backpressure: yes, the actor’s speed is the limiting factor, so some read tasks will have to wait. However, based on the experiments above, read tasks are still faster than they are now.

@wxwmd
Contributor Author

wxwmd commented Feb 11, 2026

Got it, I think this is valid. One problem is that we should avoid this actor becoming a bottleneck. Do you have any plan to avoid this? Also, this could affect our backpressure, right?

@owenowenisme what do you think about implementing a checkpoint actor pool? I think this would solve the single-actor bottleneck.
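For illustration only (not part of this PR), a simple checkpoint filter actor pool could round-robin blocks across several filter actors, reusing the illustrative BatchBasedCheckpointFilter sketch from the description above:

import itertools

import ray

POOL_SIZE = 4  # illustrative
ckpt_path, id_column = "/tmp/ray_test/ckpt/", "id"

# Each pool member holds the checkpoint IDs (a full copy here; sharding could reduce this).
pool = [BatchBasedCheckpointFilter.remote(ckpt_path, id_column) for _ in range(POOL_SIZE)]
next_filter = itertools.cycle(pool)


def filter_with_pool(block):
    # Round-robin dispatch spreads filtering requests over the pool.
    return ray.get(next(next_filter).filter_block.remote(block, id_column))

Ray also ships ray.util.ActorPool for this kind of dispatch; the trade-off is that every pool member holds its own copy of the IDs unless the pool is combined with the sharding idea discussed earlier.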


Labels

community-contribution Contributed by the community data Ray Data-related issues
