diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index da9629a56b4..6ce9c0d05e2 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -251,6 +251,14 @@ jobs: run: | ray stop --force ENGINE=sglang bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GSM8K E2E training tests on sglang async + run: | + ray stop --force + ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh + - name: Running GSM8K E2E training tests on vllm async + run: | + ray stop --force + ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh e2e_ppo_trainer_sglang_multiturn_with_tool: runs-on: [L20x8] diff --git a/tests/test_data.jsonl b/tests/test_data.jsonl new file mode 100644 index 00000000000..e5fbfd674de --- /dev/null +++ b/tests/test_data.jsonl @@ -0,0 +1,6 @@ +{"data": [1], "text": [1, 1], "index": 0} +{"data": [2], "text": [2, 2], "index": 3} +{"data": [3], "text": [3, 3], "index": 5} +{"data": [4], "text": [4, 4], "index": 1} +{"data": [5], "text": [5, 5], "index": 4} +{"data": [6], "text": [6, 6], "index": 2} diff --git a/tests/test_protocol_on_cpu.py b/tests/test_protocol_on_cpu.py index 3502bfb34da..1116e07ec62 100644 --- a/tests/test_protocol_on_cpu.py +++ b/tests/test_protocol_on_cpu.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random +import datasets import numpy as np import pytest import torch from tensordict import TensorDict +from torchdata.stateful_dataloader import StatefulDataLoader from verl import DataProto -from verl.protocol import union_numpy_dict, union_tensor_dict +from verl.protocol import LinkedDataProto, SequentialPrefetchOnPolicyDataProto, SequentialPrefetchPostponeOffPolicyDataProto, union_numpy_dict, union_tensor_dict def test_union_tensor_dict(): @@ -499,3 +502,147 @@ def test_dataproto_chunk_after_index(): selected = data[torch_int_mask] assert isinstance(selected.batch.batch_size, torch.Size) assert all(isinstance(d, int) for d in selected.batch.batch_size) + + +def get_dataloader(batch_size): + def collate_fn(batch): + flattened_data = [item["data"] for item in batch] + flattened_text = [item["text"] for item in batch] + flattened_index = [item["index"] for item in batch] + return {"data": torch.tensor(flattened_data), "text": np.array(flattened_text), "index": np.array(flattened_index)} + + train_dataset = datasets.load_dataset("json", data_files=[os.path.join(os.path.dirname(__file__), "test_data.jsonl")]) + dataloader = StatefulDataLoader( + dataset=train_dataset["train"], + batch_size=batch_size, + num_workers=0, + drop_last=True, + collate_fn=collate_fn, + ) + return dataloader + + +def test_linked_data_proto_next(): + """Test the next() method of LinkedDataProto with real DataLoader.""" + + batch_size = 2 + num_total_batches = 3 + dataloader = get_dataloader(batch_size) + + assert dataloader is not None, "DataLoader creation failed" + + assert len(dataloader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size" + + # Initialize LinkedDataProto + batch = LinkedDataProto(dataloader=dataloader) + + # Test normal iteration + count = 0 + while True: + next_batch = batch.next() + if next_batch is None: + assert batch._dataloader_exhausted, "dataloader exhausted should be True" + break # End of epoch + batch = next_batch + popped_data = batch.pop(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 1], [count * batch_size + 2]]) + assert torch.equal(popped_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 1, count * batch_size + 1], [count * batch_size + 2, count * batch_size + 2]]) + assert np.array_equal(popped_data.non_tensor_batch["text"], expected_text) + + # fetch data from next batch + next_batch = batch.next() + if count < num_total_batches - 1: + popped_data = next_batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 3], [count * batch_size + 4]]) + assert torch.equal(popped_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 3, count * batch_size + 3], [count * batch_size + 4, count * batch_size + 4]]) + assert np.array_equal(popped_data.non_tensor_batch["text"], expected_text) + else: + assert next_batch is None, "next batch should be None" + + count += 1 + + assert count == num_total_batches, f"Iterated {count} times, expected {num_total_batches}" + + +def test_prefetch_samples(): + batch_size = 2 + dataloader = get_dataloader(batch_size) + + assert dataloader is not None, "DataLoader creation failed" + + assert len(dataloader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size" + + # Initialize SequentialPrefetchOnPolicyDataProto + batch = SequentialPrefetchOnPolicyDataProto(dataloader=dataloader) + + # Test normal iteration + count = 0 + batch = batch.next() + + prefetch_sample = batch.prefetch_samples(count=1) + + # verify prefetch data + prefetched = prefetch_sample.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 3]]) + assert torch.equal(prefetched.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 3, count * batch_size + 3]]) + assert np.array_equal(prefetched.non_tensor_batch["text"], expected_text) + + # verify prefetch data are joined to current batch + batch_data = batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 1], [count * batch_size + 2], [count * batch_size + 3]]) + assert torch.equal(batch_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 1, count * batch_size + 1], [count * batch_size + 2, count * batch_size + 2], [count * batch_size + 3, count * batch_size + 3]]) + assert np.array_equal(batch_data.non_tensor_batch["text"], expected_text) + + # verify fetched data are removed from next batch + next_batch = batch.next() + next_batch_data = next_batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 4]]) + assert torch.equal(next_batch_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 4, count * batch_size + 4]]) + assert np.array_equal(next_batch_data.non_tensor_batch["text"], expected_text) + + +def test_postpone_samples(): + batch_size = 2 + dataloader = get_dataloader(batch_size) + + assert dataloader is not None, "DataLoader creation failed" + + assert len(dataloader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size" + + # Initialize SequentialPrefetchOnPolicyDataProto + batch = SequentialPrefetchPostponeOffPolicyDataProto(dataloader=dataloader) + + # Test normal iteration + count = 0 + batch = batch.next() + + samples = batch.select_idxs([1]) + + assert batch.postpone_samples(samples) + + # verify prefetch data are removed from current batch + batch_data = batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 1]]) + assert torch.equal(batch_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 1, count * batch_size + 1]]) + assert np.array_equal(batch_data.non_tensor_batch["text"], expected_text) + + # verify postponed data are joined to next batch + next_batch = batch.next() + next_batch_data = next_batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"]) + expected_data = torch.tensor([[count * batch_size + 3], [count * batch_size + 4], [count * batch_size + 2]]) + assert torch.equal(next_batch_data.batch["data"], expected_data) + + expected_text = np.array([[count * batch_size + 3, count * batch_size + 3], [count * batch_size + 4, count * batch_size + 4], [count * batch_size + 2, count * batch_size + 2]]) + assert np.array_equal(next_batch_data.non_tensor_batch["text"], expected_text) diff --git a/verl/protocol.py b/verl/protocol.py index 0461be9c022..e69ffabe3a5 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -21,6 +21,7 @@ import logging import os import pickle +import traceback from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional, Union @@ -514,6 +515,32 @@ def slice(self, start=None, end=None, step=None): # Return a new DataProto object return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + def join(self, other: "DataProto") -> "DataProto": + """ + Join another DataProto to the current one. + + Args: + other (DataProto): The DataProto to join. + + Returns: + None + """ + if not isinstance(other, DataProto): + raise TypeError(f"Can only join with another DataProto, but got {type(other)}") + + # Join batch + if self.batch is not None and other.batch is not None: + self.batch = torch.cat([self.batch, other.batch], dim=0) + elif other.batch is not None: + self.batch = other.batch + + # Join non_tensor_batch + for key, val in other.non_tensor_batch.items(): + if key in self.non_tensor_batch: + self.non_tensor_batch[key] = np.concatenate([self.non_tensor_batch[key], val], axis=0) + else: + self.non_tensor_batch[key] = val + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` @@ -839,6 +866,149 @@ def sample_level_repeat(self, repeat_times): ) +@dataclass +class LinkedDataProto(DataProto): + """ + Represents a data structure that behaves like a linked list. + Loads dataset on demand from DataLoader. + Elements can be accessed via `next()` method. + """ + + dataloader: Optional[DataLoader] = None + _current_index: int = -1 # Index of the current element in _data_cache, -1 if before the first element + _dataloader_iter: Optional[object] = None + _dataloader_exhausted: bool = False + _linked_len: int = -1 # dataload len, -1 for unknown + _next: Optional["LinkedDataProto"] = None + + def __post_init__(self): + logging.info(f"Using DataProto: {type(self)}") + if self.dataloader is not None: + self._linked_len = len(self.dataloader) + self._dataloader_iter = iter(self.dataloader) + self._dataloader_exhausted = False + else: + # If no dataloader, behave like a regular DataProto + super().__post_init__() # Call parent's post_init for consistency checks + self._dataloader_exhausted = True # No data to load + + def next(self) -> "DataProto": + """Moves to and returns the next element. Loads from dataloader if not cached.""" + if self._next: + return self._next + + if self.dataloader is None: + # No dataloader, behave like a regular DataProto + return None + + next_idx = self._current_index + 1 + if (self._linked_len != -1 and next_idx >= self._linked_len) or self._dataloader_exhausted: + self._dataloader_exhausted = True + logging.info("No more data in DataLoader configured for LinkedDataProto.") + return None + + try: + batch_dict = next(self._dataloader_iter) + loaded_batch = type(self).from_single_dict(batch_dict) + loaded_batch.dataloader = self.dataloader + loaded_batch._linked_len = self._linked_len + loaded_batch._dataloader_iter = self._dataloader_iter + loaded_batch._current_index = next_idx + loaded_batch._dataloader_exhausted = self._dataloader_exhausted + + self._next = loaded_batch + + return loaded_batch + except StopIteration: + self._dataloader_exhausted = True + self._next = None + logging.info("No next item in DataLoader") + return None + except Exception as e: + logging.error(f"Error loading next item from DataLoader: {e}") + traceback.print_exc() + return None + + def prefetch_samples(self, count=1) -> "DataProto": + return self + + def postpone_samples(self, samples: "DataProto"): + return False + + def should_stop_gen(): + return False + + +class SequentialPrefetchOnPolicyDataProto(LinkedDataProto): + """ + 1. prefetch samples from next batch + 2. do not postpone sample in order to on-policy + """ + + def prefetch_samples(self, count=1) -> "DataProto": + if count > len(self): + raise ValueError(f"count {count} must be less than batch_size {len(self)}") + + if count > len(self): + raise ValueError(f"count {count} must be less than batch_size {len(self)}") + + next_batch = self.next() + # 1. Get the prefetched samples + prefetched_data = next_batch.slice(0, count, 1) + + # 2. Remove the prefetched samples from the current batch + remaining_data = next_batch.slice(count, len(self), 1) + + # 3. Update the current instance with the remaining data + next_batch.batch = remaining_data.batch + next_batch.non_tensor_batch = remaining_data.non_tensor_batch + + # 4. join fetched data to self + self.join(prefetched_data) + + return prefetched_data + + +class SequentialPrefetchPostponeOffPolicyDataProto(LinkedDataProto): + """ + 1. prefetch samples from next batch + 2. postpone sample to next batch (off-policy) + 3. stop current batch when the original samples are generated + """ + + def postpone_samples(self, samples: "DataProto") -> bool: + next_batch = self.next() + next_batch.join(samples) + + indices_to_remove = samples.non_tensor_batch["index"] + + # Create a boolean mask for elements to keep + mask = ~np.isin(self.non_tensor_batch["index"], indices_to_remove) + + # Apply the mask to non_tensor_batch + for key, val in self.non_tensor_batch.items(): + self.non_tensor_batch[key] = val[mask] + + # Apply the mask to batch + if self.batch is not None: + self.batch = self.batch[torch.from_numpy(mask)] + + return True + + +def get_data_proto(rollout_policy: str, dataloader: DataLoader) -> LinkedDataProto: + """ + factory method to create LinkedDataProto that can be subclassed + to extend with different prefetch/postpone/stop_generate policy + """ + if rollout_policy == "SequentialPrefetchOnPolicyDataProto": + return SequentialPrefetchOnPolicyDataProto(dataloader=dataloader) + elif rollout_policy == "SequentialPrefetchPostponeOffPolicyDataProto": + return SequentialPrefetchPostponeOffPolicyDataProto(dataloader=dataloader) + else: + return LinkedDataProto(dataloader=dataloader) + + @dataclass class DataProtoFuture: """ diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1a3a9c35963..1c95a8359b2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -37,7 +37,7 @@ from tqdm import tqdm from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.protocol import get_data_proto, pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.base import Worker from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls @@ -922,7 +922,15 @@ def fit(self): last_val_metrics = None for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: + # Create a DataProto instance that will handle on-demand loading + dataproto_policy = self.config.actor_rollout_ref.rollout.get("dataproto_policy", None) + linked_batch = get_data_proto(dataproto_policy, self.train_dataloader) + while True: + linked_batch = linked_batch.next() + batch = linked_batch + if linked_batch is None: + break # End of epoch + do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False if do_profile: self.actor_rollout_wg.start_profile() @@ -935,7 +943,6 @@ def fit(self): metrics = {} timing_raw = {} - batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] @@ -946,10 +953,6 @@ def fit(self): non_tensor_batch_keys_to_pop.append("raw_prompt") if "tools_kwargs" in batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("tools_kwargs") - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) is_last_step = self.global_steps >= self.total_training_steps @@ -957,10 +960,20 @@ def fit(self): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: self.async_rollout_manager.wake_up() - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + # async rollout may dynamicly change data in batch + gen_batch_output = self.async_rollout_manager.generate_sequences(batch) + # pop those keys to align with the keys in sync batch + batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) self.async_rollout_manager.sleep() timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index b0d5470302e..153fb167f89 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -29,9 +29,8 @@ class AsyncSglangServer(AsyncServerBase): def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str): super().__init__() - self.config = config - rollout_config = config.get("rollout", {}) - self._tp_size = rollout_config.get("tensor_model_parallel_size", 1) + self.config = config.actor_rollout_ref + self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1) self._dp_size = dp_size self._dp_rank = dp_rank self.wg_prefix = wg_prefix diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 4e9c58eb714..fa06ffdc4e8 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -1084,7 +1084,7 @@ async def chat_completion(self, json_request): request_id=str(uuid4()), state=AsyncRolloutRequestStateEnum.PENDING, messages=[Message.model_validate(msg) for msg in json_request["messages"]], - tools=_tool_schemas, + tool_schemas=_tool_schemas, tools_kwargs=_tools_kwargs, input_ids=_input_ids, prompt_ids=_input_ids, @@ -1099,8 +1099,12 @@ async def chat_completion(self, json_request): prompt_loss_mask=[0] * len(_input_ids), response_loss_mask=[], reward_scores={}, + max_prompt_len=self.config.prompt_length, max_response_len=self.config.response_length, max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, + enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check, + tokenizer=self.tokenizer, ) # json_request already contains sampling_params