
[trainer,vllm] feat: add a recipe to support reorder rollout#4160

Open
echo-rain wants to merge 8 commits into verl-project:main from echo-rain:main_dev

Conversation


@echo-rain echo-rain commented Nov 17, 2025

What does this PR do?

The long-tail effect in the rollout phase is more pronounced in scenarios with a long max response length. We therefore need a solution that addresses the long-tail effect in this scenario and reduces overall training time.

[image]

The strategy in this PR comes from RollPacker.

Submit additional requests during the data preparation phase. Once the number of finished requests required for training is reached, abort all remaining rollout requests and save them to a list; in a later step, they are re-added as a long-tail batch.
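The strategy can be sketched with plain asyncio (a minimal, self-contained sketch; `rollout`, the sizes, and cancel-as-abort are stand-ins for verl's agent-loop machinery and abort interface):

```python
import asyncio

async def rollout(i: int) -> int:
    # Stand-in for one rollout request; larger i = longer response.
    await asyncio.sleep(i / 100)
    return i

async def generate(num_requests: int, needed: int):
    """Launch more requests than needed for one training batch; keep the
    first `needed` finishers, then abort the rest as long-tail requests."""
    tasks = [asyncio.create_task(rollout(i)) for i in range(num_requests)]
    done: list[int] = []
    pending = set(tasks)
    while len(done) < needed and pending:
        finished, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        done.extend(t.result() for t in finished)
    long_tail = list(pending)
    for t in long_tail:
        t.cancel()  # stand-in for the abort interface
    await asyncio.gather(*long_tail, return_exceptions=True)
    return done[:needed], long_tail

finished, long_tail = asyncio.run(generate(num_requests=10, needed=8))
```

In the real PR the aborted requests are kept (together with their raw items) so they can be resubmitted later, rather than discarded.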

Experimental data for GRPO-Qwen3-8b_bf16-tp2-ep1 with gen_batch_size=18, train_batch_size=16 (normal group vs. ablation experimental group, mean values):
[image]

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: PR #2200, PR #2929
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

example: examples/grpo_trainer/run_qwen3-8b_reorder.sh

API and Usage Example

You can enable this feature by setting the generation batch size larger than the training batch size:
data.train_batch_size=8 +data.gen_batch_size=10

Design & Code Changes

1. Data preparation stage

Modify the existing data preparation logic, adding a new variable that configures how many additional data records to prepare. Add logic during data preparation to group identified long-tail requests into batches.
[image]
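A minimal sketch of the oversampling step (`prepare_gen_batch` is a hypothetical helper; in this PR the sizes come from `data.train_batch_size` and `+data.gen_batch_size`):

```python
# Hypothetical config values; in the PR these map to
# data.train_batch_size and +data.gen_batch_size.
train_batch_size = 16
gen_batch_size = 18

def prepare_gen_batch(dataloader_iter, gen_batch_size):
    """Draw gen_batch_size prompts instead of train_batch_size, giving the
    rollout (gen - train) requests of headroom: the slowest requests can be
    aborted once a full training batch has finished."""
    return [next(dataloader_iter) for _ in range(gen_batch_size)]

gen_batch = prepare_gen_batch(iter(range(100)), gen_batch_size)
```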

2. rollout post-process

The original asyncio.gather completion check for a rollout is replaced with a check on the number of completed rollout requests. At the same time, unfinished requests are marked as long-tail requests, appended to a new long-tail request list, and aborted via the abort interface.
[image]

3. reload req

Here we offer two options as solutions: adding requests in small amounts over multiple steps, or adding them all at once in full batches.

  1. Adding small amounts over multiple steps means that long-tail requests from each step are added to the next round of rollout; once added, the long-tail request list is cleared. However, because the learning rate is limited, requests that exhibited a long-tail effect in the previous round are likely to exhibit it in this round as well. This may lead to a rollback phenomenon, i.e., a rollback to the original implementation that uses the newly added requests. If the one-step-off-policy strategy of using the previous round's streaming results as the prompt input can be applied, the rollback phenomenon can be largely avoided.
    [image]

  2. Adding all requests together means that all long-tail requests, once they form a complete batch, are added to the training process to execute a full training step. The drawback of this approach is that long-sequence requests that should be distributed across multiple training steps are concentrated into a single step, which may limit the learning rate and prevent effective learning of long-sequence representations.
    normal step: [image]
    unfinished req step: [image]
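Option 2 can be sketched with a standard-library queue standing in for the Ray queue (`drain_long_tail` is a hypothetical helper; the PR's actual implementation is `_get_all_from_queue` over `ray.util.queue.Queue`):

```python
import queue

def drain_long_tail(q: "queue.Queue", batch_size: int):
    """Drain exactly batch_size long-tail items once a full batch has
    accumulated; otherwise leave the queue untouched and return None."""
    if q.qsize() < batch_size:
        return None
    items = []
    while len(items) < batch_size:
        try:
            items.append(q.get_nowait())
        except queue.Empty:  # raced with another consumer
            break
    return items

q = queue.Queue()
for i in range(5):
    q.put(i)
batch = drain_long_tail(q, batch_size=4)
```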

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new recipe to handle long-tail rollouts by reordering them, which is a valuable addition for optimizing training time. The implementation adds new Reorder variants of AgentLoopManager and RayPPOTrainer, with the core logic relying on separate queues for finished and unfinished rollouts. My review has identified several critical issues in the asynchronous and queue handling logic that could lead to deadlocks or incorrect behavior. I've also pointed out a minor issue with a duplicated decorator. Addressing these points will be crucial for the stability and correctness of this new feature.

kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
self.tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))

results = await asyncio.gather(*self.tasks)

critical

The call to asyncio.gather is missing return_exceptions=True. Without it, if any of the awaited tasks are cancelled, asyncio.gather will raise a CancelledError immediately, and the subsequent loop to handle finished and unfinished tasks will not be executed. This would break the core logic of reordering rollouts.

Suggested change
results = await asyncio.gather(*self.tasks)
results = await asyncio.gather(*self.tasks, return_exceptions=True)
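The difference is easy to demonstrate in isolation (a minimal sketch, independent of verl): with `return_exceptions=True`, a cancelled task appears as a `CancelledError` object in the results list instead of aborting the whole `gather`.

```python
import asyncio

async def work(delay: float):
    await asyncio.sleep(delay)
    return delay

async def main():
    fast = asyncio.create_task(work(0.01))
    slow = asyncio.create_task(work(30))  # stand-in for a long-tail rollout
    await asyncio.sleep(0.05)
    slow.cancel()  # abort the long-tail request
    # return_exceptions=True lets the caller inspect finished vs. cancelled tasks.
    return await asyncio.gather(fast, slow, return_exceptions=True)

results = asyncio.run(main())
```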

ray.get(worker.set_queue.remote(self.queue))
ray.get(self.queue.append_worker.remote(worker))

ray.get(worker.set_unfinished_queue.remote(self.queue))

critical

The unfinished_queue is incorrectly set to self.queue during initialization. This will cause DataProto objects from unfinished tasks to be mixed with _InternalAgentLoopOutput objects from finished tasks in the same queue. The _postprocess function, which expects only _InternalAgentLoopOutput objects, will then fail. This line should be removed as the correct unfinished_queue is set later via set_unfinished_queue.

Comment on lines +258 to +268
def _get_all_from_queue(queue, max_item=None):
    items = []
    while not queue.empty():
        try:
            item = queue.get()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if not max_item and len(items) >= max_item:
            break
    return items

critical

The function _get_all_from_queue is intended to drain a ray.util.queue.Queue but has several critical bugs:

  1. It uses a blocking queue.get() call within a while not queue.empty() loop, which is prone to race conditions and can cause the process to hang indefinitely.
  2. It incorrectly catches asyncio.QueueEmpty instead of the Empty exception exported by ray.util.queue, which is what ray.util.queue.Queue raises.
  3. The logic for max_item is flawed and will raise a TypeError if max_item is None.

To fix this, the function should use get_nowait() in a loop and handle the queue-empty exception correctly. You will also need to import Empty from ray.util.queue at the top of the file, and import asyncio can be removed as it's no longer used.

Suggested change
def _get_all_from_queue(queue, max_item=None):
    items = []
    while not queue.empty():
        try:
            item = queue.get()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if not max_item and len(items) >= max_item:
            break
    return items
def _get_all_from_queue(queue, max_item=None):
    items = []
    while True:
        if max_item is not None and len(items) >= max_item:
            break
        try:
            item = queue.get_nowait()
            items.append(item)
        except Empty:  # Empty imported from ray.util.queue
            break
    return items



@ray.remote
@ray.remote

high

The @ray.remote decorator is duplicated. This is likely a copy-paste error and should be removed to avoid confusion.

@echo-rain echo-rain changed the title [WIP][trainer] feat: add a recipe to support reorder rollout [trainer] feat: add a recipe to support reorder rollout Nov 19, 2025
@echo-rain echo-rain changed the title [trainer] feat: add a recipe to support reorder rollout [recipe,trainer] feat: add a recipe to support reorder rollout Nov 19, 2025
@echo-rain

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new recipe to support reordering rollouts, aiming to mitigate the long-tail effect in training. The changes are substantial, adding new components like ReorderAgentLoopManager, AgentLoopReorderWorker, and RayPPOReorderTrainer. While the overall approach is sound, I have identified a few critical issues related to race conditions, incorrect API usage, and a lack of robustness in data processing. These issues could lead to deadlocks or runtime crashes and should be addressed to ensure the stability and correctness of this new feature.

Comment on lines +121 to +122
if inputs[0].response_logprobs is not None:
    optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)

critical

The current logic for handling response_logprobs is not robust and can lead to a crash. It checks if inputs[0].response_logprobs is not None and then attempts to concatenate response_logprobs from all inputs. If some inputs have response_logprobs as None while others have tensors, torch.cat will fail with a TypeError.

You should handle the case where response_logprobs can be a mix of tensors and None values in the inputs list. A possible solution is to replace None values with a tensor of zeros of the correct shape before concatenation.

    first_valid_logprob = next((inp.response_logprobs for inp in inputs if inp.response_logprobs is not None), None)
    if first_valid_logprob is not None:
        dummy_logprobs = torch.zeros_like(first_valid_logprob)
        logprobs_to_cat = [
            inp.response_logprobs if inp.response_logprobs is not None else dummy_logprobs for inp in inputs
        ]
        optional_outputs["rollout_log_probs"] = torch.cat(logprobs_to_cat, dim=0)

Comment on lines +260 to +270
def _get_all_from_queue(queue, max_item=None):
    items = []
    while not queue.empty():
        try:
            item = queue.get()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if max_item and len(items) >= max_item:
            break
    return items

critical

The implementation of _get_all_from_queue has a critical bug that can lead to a deadlock.

  1. It uses queue.get() which is a blocking call, inside a while not queue.empty() loop. This is a race condition: if the queue becomes empty between the check and the call, get() will block forever.
  2. It catches asyncio.QueueEmpty, but ray.util.queue.Queue raises the Empty exception exported by ray.util.queue for non-blocking gets.

A safer implementation would be to use queue.get_nowait() within a try...except Empty block, with Empty imported from ray.util.queue.

Suggested change
def _get_all_from_queue(queue, max_item=None):
    items = []
    while not queue.empty():
        try:
            item = queue.get()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if max_item and len(items) >= max_item:
            break
    return items
def _get_all_from_queue(queue, max_item=None):
    items = []
    num_to_get = queue.qsize()
    if max_item is not None:
        num_to_get = min(num_to_get, max_item)
    for _ in range(num_to_get):
        try:
            item = queue.get_nowait()
            items.append(item)
        except Empty:  # Empty imported from ray.util.queue
            break
    return items

ray.get(worker.set_queue.remote(self.queue))
ray.get(self.queue.append_worker.remote(worker))

ray.get(worker.set_unfinished_queue.remote(self.queue))

high

This line incorrectly sets the unfinished_queue on the worker to be a QueueMonitor actor handle, while the worker expects a ray.util.queue.Queue object. This would cause a runtime AttributeError when put_async is called on it.

Although this is currently overwritten by a subsequent call in RayPPOReorderTrainer, this line is buggy, confusing, and should be removed to avoid future issues.

@echo-rain

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a feature to reorder rollouts to mitigate the long-tail effect, which is a great initiative for optimizing training time. The changes are extensive, touching upon agent loops, the PPO trainer, and adding new queue utilities. While the overall logic for reordering seems sound, I've identified two critical issues in the implementation that could lead to process hangs or crashes. One is related to improper exception handling during task cancellation, and the other concerns incorrect queue handling in a multiprocess context. Addressing these issues is crucial for the stability of the new feature.

Comment on lines +152 to +164
try:
    server = self._choose_server(request_id)
    task = server.generate.remote(
        request_id=uuid4().hex,  # use new request_id for each turn
        prompt_ids=prompt_ids,
        sampling_params=sampling_params,
        image_data=image_data,
    )
    self.ray_tasks.append(task)
    output = await task
    return output
except Exception as e:
    logger.error(f"server manager got exception: {e}")

critical

The try...except Exception block is too broad. It catches ray.exceptions.TaskCancelledError, logs it as an error, and then implicitly returns None. The caller, _run_agent_loop, does not expect None and will crash with an AttributeError when trying to access attributes on the None object. This will prevent the cancellation logic from working correctly.

Cancellation errors should be handled specifically, and other exceptions should be re-raised to avoid hiding bugs. I suggest handling ray.exceptions.TaskCancelledError by raising an asyncio.CancelledError to correctly propagate cancellation in the asyncio context, and re-raising other exceptions.

Suggested change
try:
    server = self._choose_server(request_id)
    task = server.generate.remote(
        request_id=uuid4().hex,  # use new request_id for each turn
        prompt_ids=prompt_ids,
        sampling_params=sampling_params,
        image_data=image_data,
    )
    self.ray_tasks.append(task)
    output = await task
    return output
except Exception as e:
    logger.error(f"server manager got exception: {e}")
try:
    server = self._choose_server(request_id)
    task = server.generate.remote(
        request_id=uuid4().hex,  # use new request_id for each turn
        prompt_ids=prompt_ids,
        sampling_params=sampling_params,
        image_data=image_data,
    )
    self.ray_tasks.append(task)
    output = await task
    return output
except ray.exceptions.TaskCancelledError as e:
    raise asyncio.CancelledError from e
except Exception as e:
    logger.error(f"server manager got exception: {e}")
    raise

Comment on lines +264 to +283
def _get_all_from_queue(queue, max_item=None) -> list[Any]:
    """
    Get all items from a queue.

    Args:
        queue: queue to get items from.
        max_item: limit the number of items to return.

    Returns:
        list[Any]: list of items.
    """
    items = []
    while not queue.empty():
        try:
            item = queue.get_nowait()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if max_item and len(items) >= max_item:
            break
    return items

critical

The function _get_all_from_queue has two critical issues:

  1. It uses except asyncio.QueueEmpty. The queue being used is a ray.util.queue.Queue, which is a multiprocess queue. Its get_nowait() method raises the Empty exception exported by ray.util.queue, not asyncio.QueueEmpty. This means the exception will never be caught.
  2. The while not queue.empty(): check is a race condition in a multiprocess context. The queue can become empty between the check and the get_nowait() call. Combined with the wrong exception type, this can lead to an infinite loop.

The combination of these issues will likely cause the training process to hang. The correct way to drain a queue is to loop indefinitely and break when get_nowait() raises the appropriate empty exception.

Suggested change
def _get_all_from_queue(queue, max_item=None) -> list[Any]:
    """
    Get all items from a queue.

    Args:
        queue: queue to get items from.
        max_item: limit the number of items to return.

    Returns:
        list[Any]: list of items.
    """
    items = []
    while not queue.empty():
        try:
            item = queue.get_nowait()
            items.append(item)
        except asyncio.QueueEmpty:
            break
        if max_item and len(items) >= max_item:
            break
    return items
def _get_all_from_queue(queue, max_item=None) -> list[Any]:
    """
    Get all items from a queue.

    Args:
        queue: queue to get items from.
        max_item: limit the number of items to return.

    Returns:
        list[Any]: list of items.
    """
    items = []
    while True:
        if max_item is not None and len(items) >= max_item:
            break
        try:
            item = queue.get_nowait()
            items.append(item)
        except Empty:  # Empty imported from ray.util.queue
            break
    return items

@echo-rain

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a 'reorder rollout' mechanism to mitigate the long-tail effect in training, especially for responses with long sequences. The core idea is to identify long-running rollout requests, cancel them to avoid blocking a training step, and then group these cancelled 'long-tail' requests into separate batches for processing in later steps. This is achieved by introducing a queueing system for rollout tasks, allowing the trainer to proceed with completed tasks while unfinished ones are requeued.

My review identifies two critical issues. First, in verl/experimental/agent_loop/agent_loop.py, the _postprocess function can fail with an IndexError if it receives an empty list of inputs, which can occur if all rollout tasks are cancelled. Second, in verl/trainer/ppo/ray_trainer.py, the main training loop in the fit method is structured in a way that processing requeued long-tail batches introduces extra training steps. This incorrectly increments global_steps, which will disrupt learning rate schedules, validation frequency, and checkpointing. These issues need to be addressed to ensure the correctness and robustness of the new feature.

Comment on lines +307 to +310
def _postprocess(inputs: list[_InternalAgentLoopOutput]) -> DataProto:
    """Process the padded outputs from _run_agent_loop and combine them into a batch."""
    # Convert lists back to tensors and stack them to create a batch.
    prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0)

critical

The _postprocess function does not handle the case where inputs is an empty list. If inputs is empty, torch.cat on an empty list of tensors can raise an error (depending on the PyTorch version), and accessing inputs[0] at line 317 will definitely raise an IndexError. This scenario can occur in the reorder rollout flow if no requests are completed before being cancelled. The function should gracefully handle an empty list, for example by returning an empty DataProto. Note that this might have downstream effects, for example in _performance_metrics, which may also need to be updated to handle empty inputs.

Suggested change
def _postprocess(inputs: list[_InternalAgentLoopOutput]) -> DataProto:
    """Process the padded outputs from _run_agent_loop and combine them into a batch."""
    # Convert lists back to tensors and stack them to create a batch.
    prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0)
def _postprocess(inputs: list[_InternalAgentLoopOutput]) -> DataProto:
    """Process the padded outputs from _run_agent_loop and combine them into a batch."""
    if not inputs:
        return DataProto(
            batch=TensorDict({}, batch_size=[0]),
            non_tensor_batch={},
            meta_info={"metrics": [], "reward_extra_keys": []},
        )
    # Convert lists back to tensors and stack them to create a batch.
    prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0)

Comment on lines 1017 to +1078
for batch_dict in self.train_dataloader:
    metrics = {}
    timing_raw = {}

    with marked_timer("start_profile", timing_raw):
        self._start_profiling(
            not prev_step_profile and curr_step_profile
            if self.config.global_profiler.profile_continuous_steps
            else curr_step_profile
        )
    train_batch_size = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
    if self.reorder_rollout and self.unfinished_queue.qsize() >= train_batch_size:
        # When using reordering, some incomplete long-tail requests are generated,
        # so these requests need to be grouped into an additional batch for inference.
        unfinished_batch_list = _get_all_from_queue(self.unfinished_queue, max_item=train_batch_size)
        batch = collate_fn(unfinished_batch_list)
        self._step(batch, logger, epoch, 1)

    batch: DataProto = DataProto.from_single_dict(batch_dict)
    batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
    self._step(batch, logger, epoch, self.config.actor_rollout_ref.rollout.n)

critical

The fit loop is structured to iterate over self.train_dataloader. However, with the reorder_rollout feature, an extra training step (self._step) is performed whenever the unfinished_queue has enough items to form a batch. Since self._step increments self.global_steps, this leads to more training steps than originally planned, as self.total_training_steps is calculated based on the dataloader size. This will cause several issues:

  1. The training will run for more steps than total_training_steps.
  2. The tqdm progress bar will go beyond its total.
  3. Learning rate schedules, validation frequency, and checkpointing frequency, which are all based on self.global_steps, will be incorrect.

A possible solution would be to restructure the training loop to be based on self.global_steps (e.g., using a while loop) instead of iterating over the dataloader directly, and to handle fetching batches from either the dataloader or the unfinished_queue within that loop.
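One way to realize the suggested restructuring, sketched with toy stand-ins (the names `fit`, `unfinished_queue`, and the step bookkeeping are illustrative, not verl's actual API): drive the loop by `global_steps` and pick each batch from whichever source is ready.

```python
import queue

def fit(dataloader, unfinished_queue, total_training_steps, train_batch_size):
    """Step-count-driven training loop: each iteration consumes exactly one
    batch, taken from the long-tail queue when it holds a full batch and from
    the dataloader otherwise, so global_steps never exceeds the planned total."""
    data_iter = iter(dataloader)
    global_steps = 0
    processed = []
    while global_steps < total_training_steps:
        if unfinished_queue.qsize() >= train_batch_size:
            # Drain one full long-tail batch.
            batch = [unfinished_queue.get_nowait() for _ in range(train_batch_size)]
            source = "long_tail"
        else:
            try:
                batch = next(data_iter)
                source = "dataloader"
            except StopIteration:
                break  # dataloader exhausted and no full long-tail batch left
        processed.append((source, batch))
        global_steps += 1  # LR schedule / validation / checkpointing key off this
    return processed

uq = queue.Queue()
for item in ["t1", "t2"]:
    uq.put(item)
log = fit(dataloader=[["a"], ["b"], ["c"]], unfinished_queue=uq,
          total_training_steps=3, train_batch_size=2)
```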

Signed-off-by: shanyicheng.syc <shanyicheng.syc@antgroup.com>
… and inference batch size caused step offset.

@echo-rain echo-rain changed the title [recipe,trainer] feat: add a recipe to support reorder rollout [trainer] feat: add a recipe to support reorder rollout Nov 25, 2025
@echo-rain echo-rain changed the title [trainer] feat: add a recipe to support reorder rollout [trainer,vllm] feat: add a recipe to support reorder rollout Nov 25, 2025
@yang-ybb

@echo-rain when train_batch_size and gen_batch_size are small (like 8 and 10), the rollpacker logic works. But when train_batch_size is increased to 128, gen_batch_size to 130, and rollout.n to 16, vLLM rollout becomes very slow; it seems related to this code:

await self.queue.put.remote(
    (
        raw_item,
        loop_output,
    )
)

@echo-rain

@yang-ybb It seems unlikely to be a queue issue. Although the queue lacks multi-process concurrency optimizations, the design relies on request completion times varying, so that a large number of requests rarely completes within the same time window.
With extremely large batches, the more probable problem is the slow inference speed of the rollout itself.
However, the absence of parallel write logic does limit queue write speed under high concurrency. Could you provide more relevant data so I can assess whether adding concurrency logic is necessary?

@yang-ybb

yang-ybb commented Dec 23, 2025

@echo-rain ok~

  1. use rollpacker, train_batch_size=128, gen_batch_size=129, rollout.n=16
    First, I added timing logs around self.queue.put; the results are shown below. The first few puts are relatively fast, but later ones take significantly longer, reaching 26 seconds. The total vLLM rollout takes 245s.
Training Progress:   0%|          | 0/41500 [00:00<?, ?it/s]
(AgentLoopWorker pid=1704607) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. [repeated 7x across cluster]
(TaskRunner pid=1695273) ---------------global_steps 1 starts at 1766473912.1881988----------------
(AsyncvLLMServer pid=1699817) (EngineCore_0 pid=1701800) WARNING 12-23 07:11:30 [profiling.py:280] The sequence length (1344) is smaller than the pre-defined worst-case total number of multimodal tokens (32768). This may cause certain multi-modal inputs to fail during inference. To avoid this, you should increase `max_model_len` or reduce `mm_counts`. [repeated 7x across cluster]
(AsyncvLLMServer pid=1699817) WARNING 12-23 07:11:31 [__init__.py:1625] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`. [repeated 7x across cluster]
(RewardManagerWorker pid=1706240) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. [repeated 6x across cluster]
(AgentLoopWorker pid=1704612) You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
(RewardManagerWorker pid=1706238) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. [repeated 2x across cluster]
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.18525314331054688
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.14816617965698242
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.15647530555725098
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.22381877899169922
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.1943674087524414
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.17458271980285645
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.126600980758667
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.1065363883972168
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.11809802055358887
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.10920262336730957
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.08970022201538086
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.08512425422668457
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.11975407600402832
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.0790250301361084
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.10097265243530273
(AgentLoopWorker pid=1704612) ============in agent_loop queue put cost 0.06738114356994629
(AgentLoopWorker pid=1704607) ============in agent_loop queue put cost 0.08369851112365723 [repeated 58x across cluster]
(AgentLoopWorker pid=1704604) You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. [repeated 6x across cluster]
(AgentLoopWorker pid=1704605) ============in agent_loop queue put cost 0.30245041847229004 [repeated 152x across cluster]
(AgentLoopWorker pid=1704605) ============in agent_loop queue put cost 3.4407339096069336 [repeated 148x across cluster]
(AgentLoopWorker pid=1704605) ============in agent_loop queue put cost 7.087812662124634 [repeated 199x across cluster]
(AgentLoopWorker pid=1704605) ============in agent_loop queue put cost 10.952093839645386 [repeated 260x across cluster]
(AgentLoopWorker pid=1704610) ============in agent_loop queue put cost 14.84812307357788 [repeated 257x across cluster]
(AgentLoopWorker pid=1704610) ============in agent_loop queue put cost 18.69814658164978 [repeated 237x across cluster]
(AgentLoopWorker pid=1704610) ============in agent_loop queue put cost 22.471090078353882 [repeated 200x across cluster]
(AgentLoopWorker pid=1704606) ============in agent_loop queue put cost 26.274360418319702 [repeated 250x across cluster]
(AgentLoopWorker pid=1704604) ============in agent_loop queue put cost 26.876274585723877 [repeated 232x across cluster]
(TaskRunner pid=1695273) ---_run_reorder_worker_generate_sequences tasks finish: 1766473980.4850829. cost 65.91393113136292
(AgentLoopWorker pid=1704607) ============in agent_loop queue put cost 26.489635467529297 [repeated 39x across cluster]
(TaskRunner pid=1695273) ---_run_reorder_worker_generate_sequences queue clear finish: 1766474156.4789267. cost 241.90777111053467
(TaskRunner pid=1695273) ---------------global_steps 1 generate_sequences finished at 1766474157.9028032 costs 245.71460437774658----------------
(TaskRunner pid=1695273) ---------------global_steps 1 rewards finished at 1766474160.5839174 costs 248.39571857452393----------------
(TaskRunner pid=1695273) on_policy=True, skip compute_old_log_prob.
(TaskRunner pid=1695273) ---------------global_steps 1 old_log_prob finished at 1766474160.584271 costs 248.39607214927673----------------
(TaskRunner pid=1695273) ---------------global_steps 1 cal_ref finished at 1766474160.5842805 costs 248.3960816860199----------------
(TaskRunner pid=1695273) ---------------global_steps 1 cal_adv finished at 1766474169.90291 costs 257.71471118927----------------
(TaskRunner pid=1695273) ---------------global_steps 1 cal_critic finished at 1766474169.9029684 costs 257.7147696018219----------------

The second image shows GPU utilization:
image

  1. no RollPacker, train_batch_size=128, rollout.n=16
    Total vLLM rollout finished in 35s
Training Progress:   0%|          | 0/41500 [00:00<?, ?it/s]
(TaskRunner pid=1715642) ---------------global_steps 1 starts at 1766474490.5935266----------------
(AgentLoopWorker pid=1724728) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
(RewardManagerWorker pid=1726250) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. [repeated 8x across cluster]
(AgentLoopWorker pid=1724730) You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
(RewardManagerWorker pid=1727550) The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release. [repeated 7x across cluster]
(AgentLoopWorker pid=1724733) You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. [repeated 6x across cluster]
(TaskRunner pid=1715642) ---------------global_steps 1 generate_sequences finished at 1766474526.0778503 costs 35.48432374000549----------------
(AsyncvLLMServer pid=1719964) (EngineCore_0 pid=1721927) WARNING 12-23 07:21:26 [profiling.py:280] The sequence length (1344) is smaller than the pre-defined worst-case total number of multimodal tokens (32768). This may cause certain multi-modal inputs to fail during inference. To avoid this, you should increase `max_model_len` or reduce `mm_counts`. [repeated 7x across cluster]
(AsyncvLLMServer pid=1719964) WARNING 12-23 07:21:27 [__init__.py:1625] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`. [repeated 7x across cluster]
(TaskRunner pid=1715642) ---------------global_steps 1 rewards finished at 1766474528.536392 costs 37.9428653717041----------------
(TaskRunner pid=1715642) on_policy=True, skip compute_old_log_prob.
(TaskRunner pid=1715642) ---------------global_steps 1 old_log_prob finished at 1766474528.5366735 costs 37.94314694404602----------------
(TaskRunner pid=1715642) ---------------global_steps 1 cal_ref finished at 1766474528.5366833 costs 37.943156719207764----------------
(TaskRunner pid=1715642) ---------------global_steps 1 cal_adv finished at 1766474537.8688495 costs 47.275322914123535----------------
(TaskRunner pid=1715642) ---------------global_steps 1 cal_critic finished at 1766474537.8689034 costs 47.27537679672241----------------
image

@echo-rain
Copy link
Contributor Author

@yang-ybb Got it, thank you.
Based on the data, the queue's insertion time is higher than expected. I will try to reproduce this phenomenon locally and attempt to fix it.

@echo-rain
Copy link
Contributor Author

@yang-ybb To be precise, this performance issue isn't caused by concurrency. The root cause is that the queue's put method serializes its payload, and the batch data class is so large that this serialization becomes very slow.
Passing pointers through the queue, instead of the data itself, might be a better solution.
I will fix this problem as soon as possible.
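To illustrate the point above, here is a minimal, self-contained sketch (not verl code; the payload sizes are made up) showing that the cost of serializing a queue payload grows with its size, which is why putting a large batch object is slow while putting a small token is cheap:

```python
import pickle
import time

# Stand-ins for rollout payloads: a tiny token vs. a large batch-like object.
small = b"x" * 1_000           # ~1 KB, like putting `1` as a marker
large = b"x" * 50_000_000      # ~50 MB, like putting the whole batch

def pickle_cost(obj):
    """Return (seconds, serialized size) for pickling obj once."""
    start = time.perf_counter()
    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return time.perf_counter() - start, len(data)

t_small, n_small = pickle_cost(small)
t_large, n_large = pickle_cost(large)

print(f"small payload: {n_small} bytes in {t_small:.6f}s")
print(f"large payload: {n_large} bytes in {t_large:.6f}s")
```

Cross-process queues (and Ray actor calls) pay this serialization cost on every put, so the latency reported in the logs above scales with the batch size rather than staying constant.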

@yang-ybb
Copy link

@echo-rain Yes, it is probably the cost of putting large data through a Ray actor. The time cost returns to normal after changing to self.queue.put.remote(1) and keeping the results inside each AgentLoopWorker; then, in AgentLoopManager, worker.get_results.remote() can gather all results back to the main thread.

# AgentLoopWorker._run_agent_loop
if self.reorder_rollout:
    await self.queue.put.remote(1)  # enqueue a small marker, not the payload
    self.finish_results.append((raw_item, loop_output))
    return
else:
    return loop_output

# AgentLoopManager._run_reorder_worker_generate_sequences
- hybird_out = ray.get(self.queue.clear.remote())
+ hybird_out = ray.get([
+     worker.get_results.remote()
+     for worker in self.agent_loop_workers
+ ])

@echo-rain
Copy link
Contributor Author

@yang-ybb Great idea, thanks for the suggestion. However, in practice, any remote call serializes the data it transmits. Perhaps a better implementation would be to store the large tensors in the main process and mark them with some kind of pointer, then pass only the pointer through the queue.
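The pointer-passing idea above can be sketched with stdlib primitives (all names here are hypothetical, not verl's or Ray's API; in Ray the role of the handle would be played by an ObjectRef): the large result stays in a local store, and only a small key travels through the queue.

```python
import queue
import uuid

class ResultStore:
    """Keeps large payloads in place and hands out small string handles."""

    def __init__(self):
        self._store = {}

    def put(self, payload):
        key = uuid.uuid4().hex   # 32-char handle, cheap to serialize
        self._store[key] = payload
        return key

    def get(self, key):
        return self._store.pop(key)

store = ResultStore()
q = queue.Queue()

# Stand-in for a large rollout batch.
big_result = {"responses": list(range(100_000))}

q.put(store.put(big_result))   # only the hex key is enqueued

key = q.get()
fetched = store.get(key)
assert fetched is big_result   # the payload itself was never copied
```

The queue's per-put cost is now bounded by the handle size, independent of how large the batch grows; the consumer dereferences the handle only when it actually needs the data.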

@cboss6
Copy link
Contributor

cboss6 commented Jan 6, 2026

@echo-rain
I tried reorder_rollout with Qwen3-8B and found that it affects the convergence of training:
Clipboard_Screenshot_1767701854
Clipboard_Screenshot_1767701879

The normal behavior should be (the only difference between the scripts is whether gen_train_size equals train_batch_size):
Clipboard_Screenshot_1767702062
Clipboard_Screenshot_1767702078

Do you have any comments? As far as I know, RollPacker theoretically shouldn't affect the convergence of RL training.

@echo-rain
Copy link
Contributor Author

@cboss6 Thank you for your comment. In fact, we had already noticed this issue, which is why this PR has not been merged for so long. The results show that RollPacker does have an impact on training performance. In my opinion, adding the long requests to the same batch leads to two potential problems. First, aggregating long requests causes a sharp increase in length-related penalties. Second, long requests are more easily truncated, resulting in generally lower reward scores.
20260107114310

Of course, this is just my conjecture. If you have better conclusions or solutions, please feel free to discuss them with me.
