Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions examples/scripts/async_grpo_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
AsyncGRPO with delta weight sync (Transport B: HF Storage Bucket + in-place sparse apply).

Only changed bf16 weights are encoded as a sparse safetensors patch, uploaded to a bucket, and
applied in place on vLLM via PR #40096 — no full-model broadcast, no vLLM-side snapshot.

Start the vLLM server with the `delta` backend + worker extension (registers the engine) and the
`transformers` model impl (so vLLM's runtime param names match the trainer's HF names — every
param is then addressable by the in-place sparse apply, no fuse/unfuse remap needed):

# VLLM_USE_V2_MODEL_RUNNER=0 is required: the in-place sparse apply (apply_sparse_weight_patches,
# vLLM #40096) exists only on the V1 model runner. Without it the server picks V2 and every sparse
# delta update fails (the dense anchors still work, so it silently degrades to anchor-only sync).
CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \
--model-impl transformers \
--worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \
--weight-transfer-config '{"backend":"delta"}' \
--max-model-len 2560

CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo_delta.py
"""

import logging
import os

from datasets import load_dataset

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.rewards import accuracy_reward


# Root logger: the level comes from the LOG_LEVEL environment variable (case-insensitive name such
# as "debug" or "WARNING"); an unset variable or an unknown name falls back to INFO.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, os.getenv("LOG_LEVEL", "INFO").upper(), logging.INFO),
)
# The "trl" logger is pinned to INFO regardless of the root level chosen above.
logging.getLogger("trl").setLevel(logging.INFO)


def format_sample(sample):
    """Turn one dataset record into the prompt/solution pair used for training.

    Args:
        sample: Mapping with a ``"question"`` string and an ``"answer"`` string.

    Returns:
        A dict with ``"prompt"`` (a one-message chat: the question as a user turn) and
        ``"solution"`` (the text after the last ``"####"`` in the answer, whitespace-stripped;
        the whole stripped answer when no ``"####"`` is present).
    """
    question = sample["question"]
    # Everything after the final "####" marker is taken as the reference solution.
    solution = sample["answer"].split("####")[-1].strip()
    user_turn = {"role": "user", "content": question}
    return {"prompt": [user_turn], "solution": solution}


def main() -> None:
    """Train Qwen3-1.7B on GSM8K with AsyncGRPO using delta weight sync.

    Loads the ``openai/gsm8k`` train split, reshapes each record with ``format_sample``, builds an
    ``AsyncGRPOConfig`` with delta sync switched on, and runs ``AsyncGRPOTrainer.train()``.
    """
    train_split = load_dataset("openai/gsm8k", "main", split="train")
    # Drop the dataset's original columns while mapping each record through `format_sample`.
    source_columns = train_split.column_names
    train_split = train_split.map(format_sample, remove_columns=source_columns)

    run_settings = {
        "output_dir": "./results/async_grpo_delta",
        "per_device_train_batch_size": 1,
        "num_generations": 8,
        "max_completion_length": 512,
        "max_steps": 60,
        "learning_rate": 1e-5,
        "logging_steps": 1,
        "bf16": True,
        "report_to": "none",
        "project": "async_grpo_delta",
        "log_completions": True,
        # Qwen3 thinking traces overrun the completion cap on GSM8K (truncated -> no answer ->
        # zero reward), so thinking is turned off to keep completions short enough for
        # accuracy_reward to produce a signal.
        "chat_template_kwargs": {"enable_thinking": False},
    }
    # --- delta weight sync (Transport B) ---
    delta_settings = {
        "delta_sync_enabled": True,
        "delta_sync_repo_id": "<your-hf-username>/async-grpo-delta-demo",  # set to a bucket you own
        "delta_sync_anchor_interval": 20,  # full anchor every N syncs; sparse deltas in between
        "delta_sync_encoding": "gap_delta",  # raw | gap_delta | nvcomp_cascaded
    }
    training_args = AsyncGRPOConfig(**run_settings, **delta_settings)

    grpo_trainer = AsyncGRPOTrainer(
        model="Qwen/Qwen3-1.7B",
        args=training_args,
        train_dataset=train_split,
        reward_funcs=accuracy_reward,
    )
    grpo_trainer.train()


if __name__ == "__main__":
    main()
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ vlm = [
math_verify = [
"math-verify>=0.5.2",
]
delta_weight_sync = [
# Delta weight sync for AsyncGRPO (trl/experimental/async_grpo). Needs vLLM's sparse weight
# transfer (vllm-project/vllm#40096): merged to main after v0.22.0, not in a release yet —
# install vLLM from the nightly index until it ships (so no `vllm` pin here). Serve with
# `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`.
"huggingface-hub>=1.17.0", # HF Storage Bucket API (create/batch/download_bucket_files)
"nvidia-nvcomp-cu12>=5.2.0", # optional: only for the nvcomp_cascaded index encoding (CUDA 12)
]
openreward = [
"openreward>=0.1.109; python_version >= '3.11'", # openreward requires Python 3.11+
]
Expand Down
29 changes: 29 additions & 0 deletions trl/experimental/async_grpo/async_grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,35 @@ class AsyncGRPOConfig(_BaseConfig):
},
)

# Parameters that control delta weight sync (Transport B: sparse patches over an HF Storage Bucket)
delta_sync_enabled: bool = field(
default=False,
metadata={
"help": "Sync only the changed bf16 weights as sparse safetensors patches via an HF Storage Bucket "
"(applied in place on vLLM), instead of broadcasting the full policy over NCCL. Requires a vLLM with "
"sparse weight transfer (#40096), served with `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`."
},
)
delta_sync_repo_id: str | None = field(
default=None,
metadata={
"help": "HF Storage Bucket for the delta patches/anchors (created if missing). Required when "
"`delta_sync_enabled=True`."
},
)
delta_sync_anchor_interval: int = field(
default=10,
metadata={
"help": "Send a full anchor every N syncs; sparse deltas in between (bounds drift from missed bits)."
},
)
delta_sync_encoding: str = field(
default="gap_delta",
metadata={
"help": "Index encoding for delta patches: 'raw', 'gap_delta', or 'nvcomp_cascaded' (needs nvcomp)."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
Expand Down
58 changes: 57 additions & 1 deletion trl/experimental/async_grpo/async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from .async_grpo_config import AsyncGRPOConfig
from .async_rollout_worker import AsyncRolloutWorker
from .weight_diff import LowByteChangeDetector
from .weight_transfer import WeightTransferClient


Expand Down Expand Up @@ -413,6 +414,7 @@ def __init__(
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self._train_tokens_start_time = None
self.model_version = 0
self._change_detector: LowByteChangeDetector | None = None # delta sync only; created in compute_loss
# Create worker and queue on rank 0
if self.accelerator.is_main_process:
if self.train_dataset is None:
Expand Down Expand Up @@ -444,6 +446,10 @@ def __init__(
"packed": True,
"is_checkpoint_format": True,
},
delta_sync_enabled=self.args.delta_sync_enabled,
delta_sync_repo_id=self.args.delta_sync_repo_id,
delta_sync_anchor_interval=self.args.delta_sync_anchor_interval,
delta_sync_encoding=self.args.delta_sync_encoding,
)
self.rollout_worker = AsyncRolloutWorker(
model_name=model_name,
Expand Down Expand Up @@ -517,6 +523,10 @@ def _set_signature_columns_if_needed(self):
]

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# Register the change detector before the first optimizer.step so step 1 is captured and the
# first delta sync is already sparse (no-op unless delta sync is enabled).
self._maybe_init_change_detector()

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
Expand Down Expand Up @@ -667,8 +677,48 @@ def _streaming_iter(self):
full = full.to(device)
yield name, full

def _streaming_iter_delta(self):
"""Like [`_streaming_iter`] but yields ``(name, full, mask)`` for delta sync. With an active
change detector, only changed params are yielded (their element-level masks); otherwise all params with
``mask=None`` (the cold anchor). All ranks must walk this identically so the FSDP2 ``full_tensor()``
collectives line up.
"""
device = self.accelerator.device
masks = self._change_detector._validated_masks if self._change_detector is not None else {}
for name, param in self.model.named_parameters():
name = name.removeprefix("module.") # DDP/FSDP1 wrapping
mask = masks.get(name) if masks else None
if masks and (mask is None or not mask.any()):
continue # unchanged param -> not in this delta
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FSDP collectives skipped on delta sync

High Severity

_streaming_iter_delta skips full_tensor() for parameters deemed unchanged, but non-main ranks still walk the same loop for FSDP2 collectives. Per-rank low-byte masks can differ under sharding, so ranks may skip different parameters and deadlock or corrupt the gather.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.

full = param.full_tensor() if isinstance(param, DTensor) else param.detach()
if full.device != device:
full = full.to(device)
yield name, full, mask
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Periodic anchors omit unchanged params

High Severity

After training starts, _streaming_iter_delta only yields parameters the low-byte detector marked as changed. Periodic anchor uploads still use that iterator and only set mask=None on those tensors, so vLLM receives a partial checkpoint instead of a full model refresh. If no weights changed, the anchor upload can be skipped entirely while inference keeps stale weights.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 36d54db. Configure here.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shard masks vs full tensors

High Severity

For DTensor parameters, full_tensor() is global but change masks come from the local optimizer shard. Sparse encode pairs that mask with the gathered full tensor, so indices and values can be wrong or indexing can fail.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.


def _maybe_init_change_detector(self):
"""Create the bf16 change detector once the (prepared) optimizer exists, before its first
step, so the first delta sync is already sparse. No-op unless delta sync is enabled."""
if self.args.delta_sync_enabled and self._change_detector is None and getattr(self, "optimizer", None):
# Unwrap AcceleratedOptimizer to the native torch optimizer (register_step_*_hook).
raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer)
self._change_detector = LowByteChangeDetector(self.model, raw_optimizer)

def _sync_weight(self):
t0 = time.time()
is_delta = self.args.delta_sync_enabled

if is_delta:
# Delta phase 1: upload the sparse patch to the bucket while inference keeps running.
logger.info("Weight sync: uploading delta patch (inference still running)...")
if self.accelerator.is_main_process and self.weight_transfer:
self.weight_transfer.upload_weights(self._streaming_iter_delta())
else:
# Non-rank-0 still walks the iterator so full_tensor() collectives complete.
for _ in self._streaming_iter_delta():
pass
self.accelerator.wait_for_everyone()

# Pause vllm both delta and full
logger.info("Weight sync: pausing vLLM...")
if self.accelerator.is_main_process and self.weight_transfer:
self.weight_transfer.pause()
Expand All @@ -679,7 +729,13 @@ def _sync_weight(self):
t_barrier = time.time()

logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
if self.accelerator.is_main_process and self.weight_transfer:
if is_delta:
if self.accelerator.is_main_process and self.weight_transfer:
try:
self.weight_transfer.apply_weights_delta()
except Exception as e:
logger.warning(f"Weight sync: apply failed ({e}), skipping, vLLM will use stale weights")
elif self.accelerator.is_main_process and self.weight_transfer:
self.weight_transfer.send_weights(self._streaming_iter())
else:
# Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
Expand Down
Loading
Loading