Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/e2e_ppo_trainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ jobs:
run: |
ray stop --force
ENGINE=sglang bash tests/special_e2e/ppo_trainer/run_function_reward.sh
- name: Running GSM8K E2E training tests on sglang async
run: |
ray stop --force
ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
- name: Running GSM8K E2E training tests on vllm async
run: |
ray stop --force
ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_sglang_multiturn_with_tool:
runs-on: [L20x8]
Expand Down
6 changes: 6 additions & 0 deletions tests/test_data.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{"data": [1], "text": [1, 1], "index": 0}
{"data": [2], "text": [2, 2], "index": 3}
{"data": [3], "text": [3, 3], "index": 5}
{"data": [4], "text": [4, 4], "index": 1}
{"data": [5], "text": [5, 5], "index": 4}
{"data": [6], "text": [6, 6], "index": 2}
149 changes: 148 additions & 1 deletion tests/test_protocol_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import datasets
import numpy as np
import pytest
import torch
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader

from verl import DataProto
from verl.protocol import union_numpy_dict, union_tensor_dict
from verl.protocol import LinkedDataProto, SequentialPrefetchOnPolicyDataProto, SequentialPrefetchPostponeOffPolicyDataProto, union_numpy_dict, union_tensor_dict


def test_union_tensor_dict():
Expand Down Expand Up @@ -499,3 +502,147 @@ def test_dataproto_chunk_after_index():
selected = data[torch_int_mask]
assert isinstance(selected.batch.batch_size, torch.Size)
assert all(isinstance(d, int) for d in selected.batch.batch_size)


def get_dataloader(batch_size):
    """Build a StatefulDataLoader over tests/test_data.jsonl.

    Each JSONL row has "data" (1 int), "text" (2 ints) and "index" (int);
    the collate_fn stacks them into a tensor / numpy arrays per batch.
    """

    def collate_fn(items):
        # Stack per-sample fields into batch-level containers.
        data_col = torch.tensor([row["data"] for row in items])
        text_col = np.array([row["text"] for row in items])
        index_col = np.array([row["index"] for row in items])
        return {"data": data_col, "text": text_col, "index": index_col}

    jsonl_path = os.path.join(os.path.dirname(__file__), "test_data.jsonl")
    dataset = datasets.load_dataset("json", data_files=[jsonl_path])
    return StatefulDataLoader(
        dataset=dataset["train"],
        batch_size=batch_size,
        num_workers=0,
        drop_last=True,
        collate_fn=collate_fn,
    )


def test_linked_data_proto_next():
    """Test the next() method of LinkedDataProto with real DataLoader."""

    # test_data.jsonl holds 6 rows; with batch_size=2 and drop_last=True
    # the loader yields exactly 3 batches per epoch.
    batch_size = 2
    num_total_batches = 3
    dataloader = get_dataloader(batch_size)

    assert dataloader is not None, "DataLoader creation failed"

    assert len(dataloader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size"

    # Initialize LinkedDataProto
    batch = LinkedDataProto(dataloader=dataloader)

    # Test normal iteration
    count = 0
    while True:
        next_batch = batch.next()
        if next_batch is None:
            assert batch._dataloader_exhausted, "dataloader exhausted should be True"
            break  # End of epoch
        # Advance the cursor: `batch` always points at the current node.
        batch = next_batch
        # pop() removes the keys from the current batch and returns them.
        popped_data = batch.pop(batch_keys=["data"], non_tensor_batch_keys=["text"])
        expected_data = torch.tensor([[count * batch_size + 1], [count * batch_size + 2]])
        assert torch.equal(popped_data.batch["data"], expected_data)

        expected_text = np.array([[count * batch_size + 1, count * batch_size + 1], [count * batch_size + 2, count * batch_size + 2]])
        assert np.array_equal(popped_data.non_tensor_batch["text"], expected_text)

        # fetch data from next batch
        next_batch = batch.next()
        if count < num_total_batches - 1:
            # select() reads without mutating, so the next loop iteration
            # still sees the full batch.
            popped_data = next_batch.select(batch_keys=["data"], non_tensor_batch_keys=["text"])
            expected_data = torch.tensor([[count * batch_size + 3], [count * batch_size + 4]])
            assert torch.equal(popped_data.batch["data"], expected_data)

            expected_text = np.array([[count * batch_size + 3, count * batch_size + 3], [count * batch_size + 4, count * batch_size + 4]])
            assert np.array_equal(popped_data.non_tensor_batch["text"], expected_text)
        else:
            # Last batch: the loader is drained, so there is no successor.
            assert next_batch is None, "next batch should be None"

        count += 1

    assert count == num_total_batches, f"Iterated {count} times, expected {num_total_batches}"


def test_prefetch_samples():
    """prefetch_samples() should pull rows forward from the next batch
    into the current one, removing them from the next batch."""
    loader = get_dataloader(2)

    assert loader is not None, "DataLoader creation failed"

    assert len(loader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size"

    # Initialize SequentialPrefetchOnPolicyDataProto
    proto = SequentialPrefetchOnPolicyDataProto(dataloader=loader)
    current = proto.next()

    fetched = current.prefetch_samples(count=1)

    # The prefetched sample is the first row of the *next* batch.
    view = fetched.select(batch_keys=["data"], non_tensor_batch_keys=["text"])
    assert torch.equal(view.batch["data"], torch.tensor([[3]]))
    assert np.array_equal(view.non_tensor_batch["text"], np.array([[3, 3]]))

    # The current batch now carries its own rows plus the prefetched one.
    view = current.select(batch_keys=["data"], non_tensor_batch_keys=["text"])
    assert torch.equal(view.batch["data"], torch.tensor([[1], [2], [3]]))
    assert np.array_equal(view.non_tensor_batch["text"], np.array([[1, 1], [2, 2], [3, 3]]))

    # The next batch has lost the row that was pulled forward.
    view = current.next().select(batch_keys=["data"], non_tensor_batch_keys=["text"])
    assert torch.equal(view.batch["data"], torch.tensor([[4]]))
    assert np.array_equal(view.non_tensor_batch["text"], np.array([[4, 4]]))


def test_postpone_samples():
    """postpone_samples() should move the selected rows out of the current
    batch and append them to the next batch."""
    loader = get_dataloader(2)

    assert loader is not None, "DataLoader creation failed"

    assert len(loader) > 0, "DataLoader is empty, check test_data.jsonl and batch_size"

    # Initialize SequentialPrefetchPostponeOffPolicyDataProto
    proto = SequentialPrefetchPostponeOffPolicyDataProto(dataloader=loader)
    current = proto.next()

    # Postpone the second row of the first batch.
    deferred = current.select_idxs([1])
    assert current.postpone_samples(deferred)

    # The current batch has lost the postponed row.
    view = current.select(batch_keys=["data"], non_tensor_batch_keys=["text"])
    assert torch.equal(view.batch["data"], torch.tensor([[1]]))
    assert np.array_equal(view.non_tensor_batch["text"], np.array([[1, 1]]))

    # The next batch carries its own rows followed by the postponed one.
    view = current.next().select(batch_keys=["data"], non_tensor_batch_keys=["text"])
    assert torch.equal(view.batch["data"], torch.tensor([[3], [4], [2]]))
    assert np.array_equal(view.non_tensor_batch["text"], np.array([[3, 3], [4, 4], [2, 2]]))
170 changes: 170 additions & 0 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import pickle
import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -514,6 +515,32 @@ def slice(self, start=None, end=None, step=None):
# Return a new DataProto object
return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)

def join(self, other: "DataProto") -> None:
    """
    Concatenate another DataProto onto the end of this one, in place.

    Tensor batches are concatenated along dim 0 and non-tensor arrays along
    axis 0. Keys present only in `other` are adopted as-is; `meta_info` is
    left untouched.

    Args:
        other (DataProto): The DataProto to join.

    Returns:
        None

    Raises:
        TypeError: if `other` is not a DataProto.
    """
    # BUG FIX: the return annotation previously claimed "DataProto" although
    # the method mutates in place and (per its own docstring) returns None.
    if not isinstance(other, DataProto):
        raise TypeError(f"Can only join with another DataProto, but got {type(other)}")

    # Join batch
    if self.batch is not None and other.batch is not None:
        self.batch = torch.cat([self.batch, other.batch], dim=0)
    elif other.batch is not None:
        # NOTE(review): adopts other.batch by reference, not a copy — later
        # in-place edits would be shared between the two protos; confirm
        # this aliasing is intended.
        self.batch = other.batch

    # Join non_tensor_batch
    for key, val in other.non_tensor_batch.items():
        if key in self.non_tensor_batch:
            self.non_tensor_batch[key] = np.concatenate([self.non_tensor_batch[key], val], axis=0)
        else:
            self.non_tensor_batch[key] = val

def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

Expand Down Expand Up @@ -839,6 +866,149 @@ def sample_level_repeat(self, repeat_times):
)


@dataclass
class LinkedDataProto(DataProto):
    """
    A DataProto that behaves like a singly linked list over a DataLoader.

    Batches are loaded lazily from `dataloader`; each loaded batch is itself a
    LinkedDataProto that caches a reference to its successor, so repeated
    calls to :meth:`next` on the same node return the same object.
    """

    dataloader: Optional[DataLoader] = None
    _current_index: int = -1  # index of this node in the loader's batch sequence; -1 = head sentinel
    _dataloader_iter: Optional[object] = None  # iterator shared by every node in the chain
    _dataloader_exhausted: bool = False  # True once the iterator has been drained
    _linked_len: int = -1  # len(dataloader) when known, -1 for unknown
    _next: Optional["LinkedDataProto"] = None  # cached successor node

    def __post_init__(self):
        logging.info(f"Using DataProto: {type(self)}")
        if self.dataloader is not None:
            self._linked_len = len(self.dataloader)
            self._dataloader_iter = iter(self.dataloader)
            self._dataloader_exhausted = False
        else:
            # Without a dataloader this is just a plain DataProto; run the
            # parent's consistency checks and mark the (empty) stream exhausted.
            super().__post_init__()  # Call parent's post_init for consistency checks
            self._dataloader_exhausted = True  # No data to load

    def next(self) -> "DataProto":
        """Moves to and returns the next element. Loads from dataloader if not cached.

        Returns:
            The successor LinkedDataProto, or None when the stream is
            exhausted (or when loading fails).
        """
        if self._next:
            return self._next

        if self.dataloader is None:
            # No dataloader, behave like a regular DataProto
            return None

        next_idx = self._current_index + 1
        if (self._linked_len != -1 and next_idx >= self._linked_len) or self._dataloader_exhausted:
            self._dataloader_exhausted = True
            logging.info("No more data in DataLoader configured for LinkedDataProto.")
            return None

        try:
            batch_dict = next(self._dataloader_iter)
            loaded_batch = type(self).from_single_dict(batch_dict)
            # Propagate the shared loader state so the new node can continue the chain.
            loaded_batch.dataloader = self.dataloader
            loaded_batch._linked_len = self._linked_len
            loaded_batch._dataloader_iter = self._dataloader_iter
            loaded_batch._current_index = next_idx
            loaded_batch._dataloader_exhausted = self._dataloader_exhausted

            self._next = loaded_batch

            return loaded_batch
        except StopIteration:
            self._dataloader_exhausted = True
            self._next = None
            logging.info("No next item in DataLoader")
            return None
        except Exception as e:
            # Best-effort: log and treat any loading failure as end-of-stream.
            logging.error(f"Error loading next item from DataLoader: {e}")
            traceback.print_exc()
            return None

    def prefetch_samples(self, count=1) -> "DataProto":
        # Base policy: no prefetching — `count` is ignored; subclasses override.
        return self

    def postpone_samples(self, samples: "DataProto") -> bool:
        # Base policy: postponing is not supported; subclasses override.
        return False

    def should_stop_gen(self) -> bool:
        # BUG FIX: `self` was missing, so calling this on an instance raised
        # TypeError ("takes 0 positional arguments but 1 was given").
        return False


class SequentialPrefetchOnPolicyDataProto(LinkedDataProto):
    """
    On-policy variant:

    1. samples may be prefetched (moved forward) from the next batch;
    2. samples are never postponed, so rollouts stay on-policy.
    """

    def prefetch_samples(self, count=1) -> "DataProto":
        """Move the first `count` samples of the next batch into this one.

        Args:
            count: number of samples to pull forward from the next batch.

        Returns:
            A DataProto holding only the prefetched samples (they are also
            joined onto `self`).

        Raises:
            ValueError: if `count` exceeds the current batch size.
            RuntimeError: if there is no next batch to prefetch from.
        """
        # BUG FIX: this check was duplicated verbatim in the original.
        if count > len(self):
            raise ValueError(f"count {count} must be less than batch_size {len(self)}")

        next_batch = self.next()
        # BUG FIX: next() returns None when the loader is exhausted; the
        # original then crashed with AttributeError on next_batch.slice.
        if next_batch is None:
            raise RuntimeError("Cannot prefetch samples: dataloader is exhausted")

        # 1. Split the next batch into the prefetched head ...
        prefetched_data = next_batch.slice(0, count, 1)

        # 2. ... and the remaining tail. BUG FIX: the original bounded the
        # slice by len(self) (the *current* batch, which may have grown from
        # earlier prefetches) instead of len(next_batch).
        remaining_data = next_batch.slice(count, len(next_batch), 1)

        # 3. Shrink the next batch in place to just the tail.
        next_batch.batch = remaining_data.batch
        next_batch.non_tensor_batch = remaining_data.non_tensor_batch

        # 4. Grow the current batch with the prefetched head.
        self.join(prefetched_data)

        return prefetched_data


class SequentialPrefetchPostponeOffPolicyDataProto(LinkedDataProto):
    """
    Off-policy variant:

    1. samples may be prefetched from the next batch;
    2. samples may be postponed (moved) into the next batch — off-policy;
    3. the current batch can stop once its original samples are generated.
    """

    def postpone_samples(self, samples: "DataProto") -> bool:
        """Move `samples` out of this batch and append them to the next one.

        Rows are matched by their "index" entry in `non_tensor_batch`, so
        `samples` must carry that non-tensor field.

        Args:
            samples: the sub-batch to postpone.

        Returns:
            True when the samples were moved; False when there is no next
            batch to postpone into (dataloader exhausted).
        """
        next_batch = self.next()
        # BUG FIX: next() returns None when the loader is exhausted; the
        # original then crashed with AttributeError on next_batch.join.
        if next_batch is None:
            return False
        next_batch.join(samples)

        indices_to_remove = samples.non_tensor_batch["index"]

        # Boolean mask over the current batch: True = keep the row.
        mask = ~np.isin(self.non_tensor_batch["index"], indices_to_remove)

        # Apply the mask to non_tensor_batch
        for key, val in self.non_tensor_batch.items():
            self.non_tensor_batch[key] = val[mask]

        # Apply the mask to batch
        if self.batch is not None:
            self.batch = self.batch[torch.from_numpy(mask)]

        return True


def get_data_proto(rollout_policy: str, dataloader: DataLoader) -> LinkedDataProto:
    """
    factory method to create LinkedDataProto that can be subclassed
    to extend with different prefetch/postpone/stop_generate policy
    """
    # Dispatch table keyed by policy name; unknown names fall back to the base class.
    policy_classes = {
        "SequentialPrefetchOnPolicyDataProto": SequentialPrefetchOnPolicyDataProto,
        "SequentialPrefetchPostponeOffPolicyDataProto": SequentialPrefetchPostponeOffPolicyDataProto,
    }
    proto_cls = policy_classes.get(rollout_policy, LinkedDataProto)
    return proto_cls(dataloader=dataloader)


@dataclass
class DataProtoFuture:
"""
Expand Down
Loading
Loading