223 changes: 174 additions & 49 deletions tests/checkpoint_engine/test_special_server_adapter.py
@@ -17,39 +17,161 @@
import pytest
import ray
from omegaconf import DictConfig
from openai import AsyncOpenAI
from transformers import PreTrainedTokenizer

from tests.checkpoint_engine.test_utils import create_trainer_worker_group
from verl.checkpoint_engine import CheckpointEngineManager, CheckpointEngineWorker
from verl.checkpoint_engine import CheckpointEngineManager
from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AsyncLLMServerManager
from verl.experimental.fully_async_policy.agent_loop.agent_loop import FullyAsyncLLMServerManager
from verl.single_controller.ray import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
)
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_device_name
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import get_rollout_replica_class
from verl.workers.config import CheckpointEngineConfig, HFModelConfig


@pytest.fixture
def init_config() -> DictConfig:
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(config_name="ppo_trainer")
        config = compose(
            config_name="ppo_trainer",
            overrides=[
                "+async_training.partial_rollout_resume=True",
            ],
        )

    config.trainer.n_gpus_per_node = 8
    config.trainer.nnodes = 1
    config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct")
    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
    config.actor_rollout_ref.rollout.skip_tokenizer_init = False
    config.actor_rollout_ref.rollout.max_num_seqs = 256
    config.actor_rollout_ref.rollout.response_length = 4096
    config.actor_rollout_ref.rollout.checkpoint_engine.backend = "nccl" if get_device_name() == "cuda" else "hccl"
    config.actor_rollout_ref.rollout.nnodes = 1
    config.trainer.n_gpus_per_node = 4
    config.trainer.nnodes = 1

    return config


async def _run_update_weights_with_global_steps_none(
    server_manager: AsyncLLMServerManager,
    checkpoint_manager: CheckpointEngineManager,
    tokenizer: PreTrainedTokenizer,
):
    await checkpoint_manager.update_weights(global_steps=None)
    prompt = [{"role": "user", "content": "How to make a sandwich?"}]
    prompt_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=True)
    output = await server_manager.generate(
        request_id="test_0",
        prompt_ids=prompt_ids,
        sampling_params={
            "temperature": 1.0,
            "logprobs": True,
        },
    )
    assert output.stop_reason not in ("aborted", "abort"), (
        f"output.stop_reason is {output.stop_reason}, expected not abort"
    )
    assert output.extra_info["global_steps"] is None, (
        f"output.extra_info['global_steps'] is {output.extra_info['global_steps']}, expected None"
    )
    print("========== [update_weights with global_steps=None] ==========")
    print("[RESPONSE]", tokenizer.decode(output.token_ids, skip_special_tokens=True))


async def _run_server_manager_without_resume(
    initial_steps: int,
    train_steps: int,
    server_manager: AsyncLLMServerManager,
    checkpoint_manager: CheckpointEngineManager,
    prompts: list[list[dict]],
    tokenizer: PreTrainedTokenizer,
):
    for global_steps in range(initial_steps, initial_steps + train_steps):
        tasks = []
        for i, prompt in enumerate(prompts):
            prompt_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=True)
            tasks.append(
                asyncio.create_task(
                    server_manager.generate(
                        request_id=f"test_{global_steps}_{i}",
                        prompt_ids=prompt_ids,
                        sampling_params={
                            "temperature": 1.0,
                            "logprobs": True,
                        },
                    )
                )
            )

        # wait a while, then update weights to interrupt the generation
        await asyncio.sleep(3)
        await checkpoint_manager.update_weights(global_steps=global_steps)

        outputs = await asyncio.gather(*tasks)
        expected_steps = global_steps - 1
        for output in outputs:
            actual_steps = output.extra_info["global_steps"]
            assert output.stop_reason in ("aborted", "abort"), (
                f"output.stop_reason is {output.stop_reason}, expected aborted"
            )
            assert actual_steps == expected_steps, f"output.global_steps is {actual_steps}, expected {expected_steps}"
        print(f"========== [{initial_steps=}, {train_steps=}] ==========")
        print("[RESPONSE]", tokenizer.decode(outputs[0].token_ids, skip_special_tokens=True))


async def _run_server_manager_with_resume(
    initial_steps: int,
    train_steps: int,
    server_manager: FullyAsyncLLMServerManager,
    checkpoint_manager: CheckpointEngineManager,
    prompts: list[list[dict]],
    tokenizer: PreTrainedTokenizer,
):
    # 1. rollout generates responses
    tasks = []
    for i, prompt in enumerate(prompts):
        prompt_ids = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=True)
        tasks.append(
            asyncio.create_task(
                server_manager.generate(
                    request_id=f"test_{initial_steps}_{i}",
                    prompt_ids=prompt_ids,
                    sampling_params={
                        "temperature": 1.0,
                        "logprobs": True,
                    },
                )
            )
        )

    # 2. trainer updates weights on the rollout multiple times
    for global_steps in range(initial_steps, initial_steps + train_steps):
        # wait a while, then update weights to interrupt the generation
        await asyncio.sleep(3)
        await checkpoint_manager.update_weights(global_steps=global_steps)
Collaborator:
Does the test case include update_weights(global_steps=None)?

Collaborator (Author):
Added a test case: _run_update_weights_with_global_steps_none.
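For reference, that test boils down to the following (names taken from the diff above; prompt construction elided):

```python
# Condensed from _run_update_weights_with_global_steps_none above: a weight
# update with global_steps=None must not abort in-flight generation, and the
# resulting output should report global_steps=None.
await checkpoint_manager.update_weights(global_steps=None)
output = await server_manager.generate(
    request_id="test_0",
    prompt_ids=prompt_ids,  # a tokenized chat prompt, as in the test
    sampling_params={"temperature": 1.0, "logprobs": True},
)
assert output.stop_reason not in ("aborted", "abort")
assert output.extra_info["global_steps"] is None
```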


    # 3. wait for rollout generation to finish
    outputs = await asyncio.gather(*tasks)
    expected_min_steps = initial_steps - 1
    for output in outputs:
        min_global_steps = output.extra_info["min_global_steps"]
        max_global_steps = output.extra_info["max_global_steps"]
        assert min_global_steps == expected_min_steps, (
            f"output.min_global_steps is {min_global_steps}, expected {expected_min_steps}"
        )
        assert max_global_steps > expected_min_steps, (
            f"output.max_global_steps is {max_global_steps}, expected > {expected_min_steps}"
        )
        assert output.stop_reason not in ("aborted", "abort"), (
            f"output.stop_reason is {output.stop_reason}, expected not abort"
        )
    print(f"========== [{initial_steps=}, {train_steps=}] ==========")
    print("[RESPONSE]", tokenizer.decode(outputs[0].token_ids, skip_special_tokens=True))


@pytest.mark.asyncio
async def test_server_adapter(init_config):
    ray.init(
@@ -69,53 +191,56 @@ async def test_server_adapter(init_config):
    checkpoint_engine_config: CheckpointEngineConfig = omega_conf_to_dataclass(
        init_config.actor_rollout_ref.rollout.checkpoint_engine
    )
    trainer_pool = RayResourcePool(process_on_nodes=[4], max_colocate_count=3)
    trainer_pool = RayResourcePool(process_on_nodes=[init_config.trainer.n_gpus_per_node], max_colocate_count=3)
    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
    trainer.reset()

    # 2. create rollout replicas
    rollout_config: RolloutConfig = omega_conf_to_dataclass(init_config.actor_rollout_ref.rollout)

    # 2.1 create checkpoint engine worker group
    rollout_pool = RayResourcePool(process_on_nodes=[4], max_colocate_count=3)
    ray_cls_with_init = RayClassWithInitArgs(
        cls=ray.remote(CheckpointEngineWorker),
        model_config=model_config,
        rollout_config=rollout_config,
    )
    rollout = RayWorkerGroup(
        resource_pool=rollout_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name()
    )

    # 2.2 create rollout replicas
    rollout_replica_class = get_rollout_replica_class(rollout_config.name)
    rollout_replicas = [
        rollout_replica_class(
            replica_rank=replica_rank,
            config=rollout_config,
            model_config=model_config,
        )
        for replica_rank in range(2)
    ]
    await asyncio.gather(*[replica.init_hybrid(rollout) for replica in rollout_replicas])
    # 2. create standalone rollout with AgentLoopManager
    agent_loop_manager = await AgentLoopManager.create(config=init_config)
    server_handles = [server._server_handle for server in agent_loop_manager.rollout_replicas]

    # 3. create checkpoint engine manager
    checkpoint_manager = CheckpointEngineManager(
        config=checkpoint_engine_config, trainer=trainer, replicas=rollout_replicas
        config=checkpoint_engine_config, trainer=trainer, replicas=agent_loop_manager.rollout_replicas
    )
    for i in range(3):
        await checkpoint_manager.update_weights()

        server_addresses = rollout_replicas[i % len(rollout_replicas)].server_address
        client = AsyncOpenAI(
            api_key="123-abc",
            base_url=f"http://{server_addresses}/v1",
        )
    n = 4
    prompts = [
        [{"role": "user", "content": "Please write an article about the history of China, at least 1000 words."}],
        [{"role": "user", "content": "Please write an article about the history of America, at least 1000 words."}],
        [{"role": "user", "content": "Please write an article about the geography of China, at least 1000 words."}],
        [{"role": "user", "content": "Please write an article about the geography of America, at least 1000 words."}],
    ] * n

    server_manager = AsyncLLMServerManager(config=init_config, server_handles=server_handles)

    # 4. test update_weights with global_steps=None
    await _run_update_weights_with_global_steps_none(
        server_manager=server_manager,
        checkpoint_manager=checkpoint_manager,
        tokenizer=model_config.tokenizer,
    )

        completion = await client.chat.completions.create(
            model=init_config.actor_rollout_ref.model.path,
            messages=[{"role": "user", "content": "What can you do?"}],
        )
        print("[OUTPUT]:", completion.choices[0].message.content)
    # 5. test AsyncLLMServerManager without partial rollout resume
    await checkpoint_manager.update_weights(global_steps=0)
    await _run_server_manager_without_resume(
        initial_steps=1,
        train_steps=3,
        server_manager=server_manager,
        checkpoint_manager=checkpoint_manager,
        prompts=prompts,
        tokenizer=model_config.tokenizer,
    )

    # 6. test FullyAsyncLLMServerManager with partial rollout resume
    server_manager = FullyAsyncLLMServerManager(config=init_config, server_handles=server_handles)
    await _run_server_manager_with_resume(
        initial_steps=4,
        train_steps=3,
        server_manager=server_manager,
        checkpoint_manager=checkpoint_manager,
        prompts=prompts,
        tokenizer=model_config.tokenizer,
    )

    ray.shutdown()
2 changes: 1 addition & 1 deletion tests/checkpoint_engine/test_utils.py
@@ -40,7 +40,7 @@ def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: Check
        self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self):
    async def update_weights(self, global_steps: int | None = None):
        per_tensor_param, _ = self.engine.get_per_tensor_param()
        await self.checkpoint_engine.send_weights(per_tensor_param)

18 changes: 11 additions & 7 deletions verl/checkpoint_engine/base.py
@@ -283,9 +283,9 @@ def __init__(
        initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo")

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self):
    async def update_weights(self, global_steps: int | None = None):
        weights = self.checkpoint_engine.receive_weights()
        await self.server_adapter.update_weights(weights)
        await self.server_adapter.update_weights(weights, global_steps=global_steps)

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def execute_checkpoint_engine(self, method: str, *args, **kwargs):
@@ -399,12 +399,16 @@ async def sleep_replicas(self):
        await asyncio.gather(*[r.sleep() for r in self.replicas])

    @auto_await
    async def update_weights(self):
        """Update weights from trainer to rollout replicas."""
    async def update_weights(self, global_steps: int | None = None):
        """Update weights from trainer to rollout replicas.

        Args:
            global_steps: The trainer's current global step; may be None.
        """

        # 0. update weights for sync training with colocated trainer and rollout
        if self.backend == "naive":
            ray.get(self.trainer.update_weights())
            ray.get(self.trainer.update_weights(global_steps=global_steps))
            return

        # 1. abort and save all unfinished requests for partial rollout
@@ -421,7 +425,7 @@ async def update_weights(self):
        self.build_process_group(rollout)

        # 4. update weights of all workers
        ray.get(trainer.update_weights() + rollout.update_weights())
        ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps))

        # 5. finalize all workers
        ray.get(
@@ -430,4 +434,4 @@
        )

        # 6. resume all unfinished requests for partial rollout
        await asyncio.gather(*[r.resume_all_requests() for r in self.replicas])
        await asyncio.gather(*[r.resume_generation() for r in self.replicas])
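Taken together, the new `global_steps` plumbing lets a training loop tag each weight push with its step. A minimal, hypothetical driver sketch (assuming the `CheckpointEngineManager` wiring from `test_server_adapter` above; the training step itself is elided and not part of this diff):

```python
# Hypothetical driver loop; assumes `trainer` and `replicas` are constructed
# as in test_server_adapter above.
checkpoint_manager = CheckpointEngineManager(
    config=checkpoint_engine_config, trainer=trainer, replicas=replicas
)
for global_steps in range(1, num_train_steps + 1):
    ...  # run one training step on the trainer workers
    # Aborts and saves unfinished requests, pushes weights from trainer to
    # rollout workers tagged with this step, then resumes generation.
    await checkpoint_manager.update_weights(global_steps=global_steps)
```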