diff --git a/docs/source/async_grpo_trainer.md b/docs/source/async_grpo_trainer.md index fdb85a231d2..85deaca7976 100644 --- a/docs/source/async_grpo_trainer.md +++ b/docs/source/async_grpo_trainer.md @@ -34,7 +34,7 @@ The rollout worker runs in a separate process spawned from the trainer, so rewar > > If you do need a GPU reward model, the recommended approach is to **serve it behind its own inference engine** (vLLM, TGI, …) on separate GPUs and have a lightweight, picklable reward function call it over HTTP. This keeps the reward model on its own device while the rollout process stays CPU-only, and it scales independently of the trainer. -After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy. +After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server so that subsequent generations reflect the latest policy. How they are transferred is configurable; see [Weight synchronization](#weight-synchronization) below. Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded. @@ -76,6 +76,52 @@ CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py ``` +## Weight synchronization + +At each weight sync, the trainer pushes the updated policy to the vLLM server. Two options control how: + +- **`weight_sync_mode`**: `"sparse"` (default) or `"full"`. + - `"sparse"` sends **only the bf16 weights that changed** in the step. The changed set is recovered by _inverting_ + the AdamW update from the optimizer's resident moments (`θ_old = (θ_t + lr·m̂/(√v̂+ε)) / (1−lr·wd)`) and diffing + against the live weights, so **no pre-step snapshot is kept**.
It requires a `torch.optim.AdamW` optimizer (the + trainer raises otherwise) and a vLLM with sparse weight transfer ([vllm-project/vllm#40096](https://github.com/vllm-project/vllm/pull/40096)), served with `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`. A full **anchor** is sent every `weight_sync_anchor_interval` syncs to bound drift. + - `"full"` broadcasts the entire policy every sync. Use it when the optimizer is not AdamW. It is always sent over NCCL. +- **`weight_sync_backend`**: the transport for sparse patches: `"nccl"` (default) or `"bucket"`. + +| backend | data plane | when to use | +| ---------- | ----------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | +| `"nccl"` | NCCL broadcast over a group shared with vLLM | trainer and vLLM **co-located** (same node / NVLink). ~100× faster per sync. | +| `"bucket"` | sparse patches uploaded to an **HF Storage Bucket**, applied in place on vLLM | trainer and vLLM on **different hosts** (e.g. a remote vLLM HF Space). Object-storage latency (~seconds/sync). | + +Serve the vLLM side to match the backend: + +```bash +# nccl backend (default): co-located trainer + vLLM +VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-4B --model-impl transformers \ + --weight-transfer-config '{"backend":"nccl"}' + +# bucket backend: register the engine via the worker extension +VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-4B --model-impl transformers \ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.HFBucketWorkerExtension \ + --weight-transfer-config '{"backend":"hf_bucket"}' +``` + +### Disaggregating training and inference + +The `"bucket"` backend decouples _where training runs_ from _where generation runs_: the control plane is plain HTTP +(`vllm_server_base_url`) and the data plane is the HF Hub (a bucket reachable from anywhere). 
Nothing requires the +trainer and the vLLM server to share a machine, a network, or an NCCL group. So you can keep the trainer on your local +training GPUs and serve generation from a **remote vLLM HF Space** (or several, scaled independently), syncing only the +~1% of weights that change each step. + +The end-to-end example (local trainer + remote vLLM Space + remote environment Space + bucket) is in +[`examples/scripts/async_grpo_buckets/async_grpo_buckets.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/async_grpo_buckets) (see its `README.md` for the deploy + run guide). + +> [!TIP] +> Sparse sync cost is roughly flat in model size (only the changed elements move), while a full broadcast grows with +> the model, so the sparse advantage widens for larger policies. On a single node, `"nccl"` is the fast default; reach +> for `"bucket"` specifically for the cross-host / remote-Space setup. + ## Design philosophy This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand. diff --git a/examples/scripts/async_grpo.py b/examples/scripts/async_grpo.py index ccd020b9d13..e936e3d13b3 100644 --- a/examples/scripts/async_grpo.py +++ b/examples/scripts/async_grpo.py @@ -24,12 +24,23 @@ """ pip install math_verify -CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ +AsyncGRPO defaults to *sparse* weight sync over NCCL: only the bf16 weights changed by each optimizer step are +broadcast and applied in place on vLLM (the changed set is recovered by inverting the AdamW step from the resident +optimizer moments, with no snapshot kept).
This needs a vLLM with sparse weight transfer (vllm-project/vllm#40096), the +`transformers` model impl (so vLLM's runtime param names match the trainer's HF names), and the V1 model runner +(`apply_sparse_weight_patches` is V1-only): + +CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-0.6B \ + --model-impl transformers \ --max-model-len 2048 \ --logprobs-mode processed_logprobs \ --weight-transfer-config '{"backend":"nccl"}' CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo.py + +To fall back to broadcasting the full policy every sync (e.g. a non-AdamW optimizer), set +`weight_sync_mode="full"` in the config and serve without the sparse-only flags (a plain +`--weight-transfer-config '{"backend":"nccl"}'` is enough). """ from datasets import load_dataset diff --git a/examples/scripts/async_grpo_buckets/README.md b/examples/scripts/async_grpo_buckets/README.md new file mode 100644 index 00000000000..9c21eba5b89 --- /dev/null +++ b/examples/scripts/async_grpo_buckets/README.md @@ -0,0 +1,91 @@ +# Disaggregated async GRPO with bucket weight sync (`async_grpo_buckets.py`) + +Train a policy on your **local GPU** while a **remote vLLM HF Space** does generation — the two never share a NCCL +group. They stay in sync through an **HF Storage Bucket**: after each optimizer step the trainer uploads only the bf16 +weights that changed (a sparse patch, recovered by inverting the AdamW step — no snapshot), and the remote vLLM applies +it in place. A full **anchor** is sent every N syncs to bound drift. + +This is the power of disaggregation: **training and inference scale and live independently.** Put the trainer wherever +your training GPUs are, serve generation from an autoscaling Space (or many), and connect any environment server — all +glued together by a bucket and plain HTTP. 
+ +``` + ┌──────────────────────────┐ sparse patches / anchors ┌───────────────────────────┐ + │ Local trainer (1 GPU) │ ───────────────────────────────▶ │ HF Storage Bucket │ + │ AsyncGRPOTrainer │ │ anchors/ + deltas/ │ + │ + rollout worker │ ◀─────────────────────────────── └───────────────────────────┘ + └──────────┬───────────────┘ apply in place ▲ + │ /v1/completions (HTTP) │ fetch + ▼ │ + ┌──────────────────────────┐ ┌───────────────────────────┐ + │ vLLM HF Space (GPU) │ ◀────────────────────────────── │ HFBucketWorkerExtension │ + │ serves generation │ │ (hf_bucket backend) │ + └──────────┬───────────────┘ └───────────────────────────┘ + │ tool calls (HTTP) + ▼ + ┌──────────────────────────┐ + │ Wordle env HF Space │ (no GPU; public one at openenv-wordle.hf.space) + └──────────────────────────┘ +``` + +Files in this directory: + +- `async_grpo_buckets.py` — the local trainer (AsyncGRPO + Wordle env, `weight_sync_backend="bucket"`). +- `vllm_space/` — Dockerfile + README to deploy the **vLLM inference Space** (GPU). +- `wordle_space/` — Dockerfile + README to deploy your own **Wordle environment Space** (optional; a public one exists). + +## Prerequisites + +```sh +pip install "trl @ git+https://github.com/huggingface/trl.git@delta-weight-sync-v3" +pip install "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle" # the Wordle env client +hf auth login # needs write access to create the bucket + (for Option 1) deploy Spaces +``` + +The vLLM side needs a build with sparse weight transfer (vllm-project/vllm#40096); the Space Dockerfile installs it +from the nightly index. Locally, install the same nightly (see the repo's `dev_delta_v2/INSTALL.md`). + +### Step 1 — deploy the vLLM inference Space (GPU) + +```sh +# Create the Space (l4 GPU, Docker SDK). HF_TOKEN lets the Space read the bucket. 
+hf repos create /vllm-wordle-inference \ + --repo-type space --space-sdk docker + +hf upload /vllm-wordle-inference \ + examples/scripts/async_grpo_buckets/vllm_space/ . --repo-type space + +# Set the GPU + secrets/vars in the Space settings (or via the CLI / web UI): +# hardware: l4x1 ; secret HF_TOKEN= ; the Dockerfile already sets VLLM_SERVER_DEV_MODE=1 +``` + +Wait until `https://-vllm-wordle-inference.hf.space/health` returns 200 (first build pulls the image and +loads the model — a few minutes). + +### Step 2 — (optional) deploy your own Wordle env Space + +A public env runs at `https://openenv-wordle.hf.space`. To run your own (higher concurrency), deploy `wordle_space/` +the same way and pass its URL via `--env-url`. + +### Step 3 — train locally (1 GPU) + +```sh +CUDA_VISIBLE_DEVICES=0 python examples/scripts/async_grpo_buckets/async_grpo_buckets.py \ + --model Qwen/Qwen3-1.7B \ + --vllm-server-url https://-vllm-wordle-inference.hf.space \ + --env-url https://openenv-wordle.hf.space \ + --weight-sync-bucket-id /wordle-deltas +``` + +The bucket (`/wordle-deltas`) is created automatically on the first sync. + +## Notes + +- **Bucket vs NCCL.** Bucket sync works across hosts/Spaces (data plane = HF Hub, control plane = HTTP), at the cost of + object-storage latency (~seconds/sync). On a single node where the trainer and vLLM share NVLink, the default + `weight_sync_backend="nccl"` is ~100× faster — use the bucket backend specifically for the disaggregated/cross-host + case this example demonstrates. +- **Anchors.** `--weight-sync-anchor-interval N` uploads a full checkpoint every N syncs (sparse deltas in between) to + bound drift from any missed bits. Lower N = more robust, larger uploads. +- **Key flags must match** between the example and the Space: `Qwen/Qwen3-1.7B`, the `hf_bucket` backend, and the + `HFBucketWorkerExtension`. 
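The sparse patch described in this README depends on recovering the pre-step weights by inverting the AdamW step instead of keeping a snapshot. The following is a minimal sketch of that idea on plain Python floats, assuming the textbook decoupled-weight-decay AdamW update. `adamw_step` and `invert_adamw_step` are hypothetical helper names used for illustration only; this is not TRL's implementation, which works on bf16 tensors and the optimizer's own state.

```python
import math

# Hyperparameters shared by the forward step and its inverse (illustrative values).
HP = dict(lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01)


def _adam_update(m, v, t, lr, beta1, beta2, eps):
    """Bias-corrected Adam update term lr * m_hat / (sqrt(v_hat) + eps), per element."""
    bc1 = 1 - beta1**t
    bc2 = 1 - beta2**t
    return [lr * (mi / bc1) / (math.sqrt(vi / bc2) + eps) for mi, vi in zip(m, v)]


def adamw_step(theta, grad, m, v, t, lr, beta1, beta2, eps, wd):
    """One AdamW step; returns (theta_new, m_new, v_new) without mutating inputs."""
    m_new = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v_new = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    upd = _adam_update(m_new, v_new, t, lr, beta1, beta2, eps)
    theta_new = [th * (1 - lr * wd) - u for th, u in zip(theta, upd)]
    return theta_new, m_new, v_new


def invert_adamw_step(theta_new, m_new, v_new, t, lr, beta1, beta2, eps, wd):
    """Recover theta_old = (theta_new + lr*m_hat/(sqrt(v_hat)+eps)) / (1 - lr*wd)."""
    upd = _adam_update(m_new, v_new, t, lr, beta1, beta2, eps)
    return [(th + u) / (1 - lr * wd) for th, u in zip(theta_new, upd)]


if __name__ == "__main__":
    theta = [0.5, -1.25, 2.0]
    grad = [0.1, 0.0, -0.3]
    new, m, v = adamw_step(theta, grad, [0.0] * 3, [0.0] * 3, t=1, **HP)
    print(invert_adamw_step(new, m, v, t=1, **HP))
```

In double precision the round trip matches the original weights up to rounding error. With bf16 weights the inversion can presumably miss for some elements, which would be the kind of drift the periodic full anchor is there to bound.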
diff --git a/examples/scripts/async_grpo_buckets/async_grpo_buckets.py b/examples/scripts/async_grpo_buckets/async_grpo_buckets.py new file mode 100644 index 00000000000..187c2344670 --- /dev/null +++ b/examples/scripts/async_grpo_buckets/async_grpo_buckets.py @@ -0,0 +1,302 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "trackio", +# "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle", +# ] +# /// + + +""" +Async GRPO with **disaggregated** training and inference: a local trainer fine-tunes the policy on 1 GPU while a +remote vLLM HF Space serves generation, and the two are kept in sync by **bucket weight sync** — only the bf16 weights +changed by each optimizer step are uploaded to an HF Storage Bucket as a sparse patch and applied in place on the +remote vLLM (`weight_sync_backend="bucket"`). No NCCL group between trainer and server, so they can live anywhere. 
+ +Architecture: + Local (1 GPU): AsyncGRPOTrainer + rollout worker (Wordle tool calls run locally) + Remote Space 1: vLLM server with the HFBucketWorkerExtension (GPU, serves /v1/completions + applies patches) + Remote Space 2: TextArena Wordle game server (no GPU; a public one runs at openenv-wordle.hf.space) + HF Storage Bucket: holds the weight anchors (full) and sparse deltas + +See `examples/scripts/async_grpo_buckets/README.md` for the full, copy-pasteable deploy + run guide. Quick reference: + +# Option 1 — fully remote inference (vLLM on an HF Space) + +Deploy the vLLM Space from `examples/scripts/async_grpo_buckets/vllm_space/` (see the README there), then run locally: + +```sh +CUDA_VISIBLE_DEVICES=0 python examples/scripts/async_grpo_buckets/async_grpo_buckets.py \\ + --vllm-server-url https://.hf.space \\ + --env-url https://openenv-wordle.hf.space \\ + --weight-sync-bucket-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` + +# Option 2 — local vLLM (for testing the bucket path on one node) + +```sh +# Terminal 1: vLLM with the bucket backend (transformers impl + V1 runner for the in-place sparse apply) +CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \\ + --model-impl transformers \\ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.HFBucketWorkerExtension \\ + --weight-transfer-config '{"backend":"hf_bucket"}' \\ + --max-model-len 8192 --gpu-memory-utilization 0.8 --logprobs-mode processed_logprobs + +# Terminal 2: training +CUDA_VISIBLE_DEVICES=1 python examples/scripts/async_grpo_buckets/async_grpo_buckets.py \\ + --vllm-server-url http://localhost:8000 \\ + --weight-sync-bucket-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` +""" + +import argparse +import functools +import logging +import os + +from datasets import Dataset +from textarena_env import TextArenaAction, TextArenaEnv + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer + + 
+logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Async GRPO for Wordle with bucket weight sync to a remote vLLM HF Space." + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-1.7B", + help="Model identifier passed to AsyncGRPOTrainer for fine-tuning.", + ) + parser.add_argument( + "--env-url", + type=str, + default="https://openenv-wordle.hf.space", + help="URL for the Wordle environment server.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8000", + help="URL for the vLLM server (local or remote HF Space).", + ) + parser.add_argument( + "--weight-sync-bucket-id", + type=str, + default=None, + help="HF Storage Bucket for the weight anchors + sparse deltas (e.g. 'user/wordle-deltas'). Required.", + ) + parser.add_argument( + "--weight-sync-anchor-interval", + type=int, + default=10, + help="Upload a full anchor checkpoint every N weight syncs; sparse deltas in between.", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=1000, + help="Number of entries in the synthetic training dataset.", + ) + parser.add_argument( + "--num-generations", + type=int, + default=16, + help="Number of rollout generations per prompt.", + ) + parser.add_argument( + "--max-completion-length", + type=int, + default=1024, + help="Maximum number of tokens generated per turn.", + ) + parser.add_argument( + "--max-tool-calling-iterations", + type=int, + default=3, + help="Maximum number of guess turns per Wordle game.", + ) + parser.add_argument( + "--per-device-train-batch-size", + type=int, + default=32, + help="Per-device training batch size.", + ) + parser.add_argument( + "--max-steps", + type=int, + default=100, + help="Maximum number of training steps.", + ) + parser.add_argument( + 
"--learning-rate", + type=float, + default=1e-6, + help="Learning rate for training.", + ) + parser.add_argument( + "--max-staleness", + type=int, + default=5, + help="Drop rollout samples generated more than this many weight versions ago.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory where training outputs and checkpoints are stored.", + ) + parser.add_argument( + "--trackio-space-id", + type=str, + default=None, + help="Trackio space identifier for logging (optional).", + ) + return parser.parse_args() + + +prompt = """You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies. + +Follow these rules to play Wordle: + +1. The target is a 5-letter English word +2. You have 6 attempts to guess the correct word +3. After each guess, you receive color-coded feedback: + - GREEN (G): Letter is correct and in the correct position + - YELLOW (Y): Letter is in the word but in the wrong position + - GRAY (X): Letter is not in the word at all +4. All guesses must be valid 5-letter English words +5. You cannot reuse a word you've already guessed +6. Use the tool `guess` to make a guess. +""" + + +def reward_func(environments, **kwargs) -> list[float]: + return [env.reward for env in environments] + + +# Defined at module level (not nested in `main`) so it is picklable: the rollout worker runs in a spawned child +# process and pickles `environment_factory`. `env_url` is bound per-run via `functools.partial(WordleEnv, env_url)`. 
+class WordleEnv: + def __init__(self, env_url: str): + self.env_url = env_url + self.client = TextArenaEnv(base_url=env_url).sync() + self.reward = 0.0 + self.done = False + + def _reconnect(self): + self.client = TextArenaEnv(base_url=self.env_url).sync() + + def reset(self, **kwargs) -> str | None: + try: + result = self.client.reset() + except Exception: + self._reconnect() + result = self.client.reset() + # The game returns cumulative feedback each turn (new text appended at the end), so + # we store the previous full response and slice out only the newly appended part. + self._last_full_feedback = result.observation.messages[0].content + self.reward = 0.0 + self.done = False + return self._last_full_feedback + + def guess(self, guess: str) -> str: + """ + Make a guess in the Wordle environment. + + Args: + guess: The guessed word, formatted as '[abcde]' + + Returns: + The feedback message from the environment. + """ + if self.done: + raise ValueError("Game over.") + try: + result = self.client.step(TextArenaAction(message=guess)) + except Exception: + self._reconnect() + result = self.client.step(TextArenaAction(message=guess)) + _full_feedback = result.observation.messages[0].content + # Just take the new feedback since the last guess, which is the part appended to the end of the full feedback + feedback = _full_feedback[len(self._last_full_feedback) :] + self._last_full_feedback = _full_feedback + # For some reason, the environment doesn't penalize invalid moves and just returns the last reward. + # We check the feedback for the invalid move message and penalize it if found. + if "You attempted an invalid move" in feedback: + self.reward = 0.0 + else: + self.reward = result.reward + self.done = result.done + return feedback + + +def main() -> None: + args = parse_args() + if args.weight_sync_bucket_id is None: + raise ValueError("--weight-sync-bucket-id is required (e.g. 
'your-username/wordle-deltas').") + + output_dir = args.output_dir or f"{args.model.split('/')[-1]}-async-wordle-GRPO" + dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]}) + + config = AsyncGRPOConfig( + vllm_server_base_url=args.vllm_server_url, + learning_rate=args.learning_rate, + bf16=True, + output_dir=output_dir, + max_completion_length=args.max_completion_length, + max_tool_calling_iterations=args.max_tool_calling_iterations, + per_device_train_batch_size=args.per_device_train_batch_size, + num_generations=args.num_generations, + max_staleness=args.max_staleness, + max_steps=args.max_steps, + logging_steps=1, + log_completions=True, + num_completions_to_print=1, + report_to="trackio", # logs locally — view with `trackio show --project async-wordle-buckets` + trackio_space_id=args.trackio_space_id, # also sync to a remote Space only if --trackio-space-id is given + project="async-wordle-buckets", # dedicated trackio project so runs get their own clean dashboard + chat_template_kwargs={"enable_thinking": False}, + # --- bucket weight sync: only the changed bf16 weights are shipped to the remote vLLM via an HF bucket --- + weight_sync_mode="sparse", + weight_sync_backend="bucket", + weight_sync_bucket_id=args.weight_sync_bucket_id, + weight_sync_anchor_interval=args.weight_sync_anchor_interval, + ) + + trainer = AsyncGRPOTrainer( + model=args.model, + args=config, + train_dataset=dataset, + reward_funcs=reward_func, + environment_factory=functools.partial(WordleEnv, args.env_url), + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/async_grpo_buckets/vllm_space/Dockerfile b/examples/scripts/async_grpo_buckets/vllm_space/Dockerfile new file mode 100644 index 00000000000..53ccc69f980 --- /dev/null +++ b/examples/scripts/async_grpo_buckets/vllm_space/Dockerfile @@ -0,0 +1,40 @@ +# vLLM inference Space for async GRPO with bucket weight sync. 
+# +# Serves Qwen/Qwen3-1.7B and applies sparse weight patches in place from an HF Storage Bucket via the +# `hf_bucket` weight-transfer backend (the HFBucketWorkerExtension registers it in the worker process). +# +# Requirements baked in here: +# * a vLLM with the sparse weight-transfer API (vllm-project/vllm#40096) — installed from the nightly index, +# since it is not in a stable release yet. +# * `--model-impl transformers` -> vLLM's runtime param names match the trainer's HF names, so every param is +# addressable by the in-place sparse apply (no fuse/unfuse remap). +# * `VLLM_USE_V2_MODEL_RUNNER=0` -> the in-place sparse apply (apply_sparse_weight_patches) is V1-runner only. +FROM vllm/vllm-openai:latest + +RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* + +# Overlay the vLLM nightly that contains #40096 (sparse weight transfer), then TRL (v3 branch) + a recent transformers. +RUN pip install -U --pre vllm --extra-index-url https://wheels.vllm.ai/nightly/cu129 +RUN pip install "transformers>=5.2.0" +RUN pip install "trl @ git+https://github.com/huggingface/trl.git@delta-weight-sync-v3" + +# HF Spaces serves on port 7860 and runs as uid 1000 without a passwd entry — set HOME/USER + writable caches. 
+EXPOSE 7860 +ENV VLLM_SERVER_DEV_MODE=1 +ENV VLLM_USE_V2_MODEL_RUNNER=0 +ENV USER=user +ENV HOME=/tmp +ENV HF_HOME=/tmp/hf_cache +ENV TORCH_HOME=/tmp/torch_cache +ENV XDG_CACHE_HOME=/tmp/.cache +ENV FLASHINFER_WORKSPACE_DIR=/tmp/flashinfer + +ENTRYPOINT ["vllm", "serve", "Qwen/Qwen3-1.7B", \ + "--host", "0.0.0.0", \ + "--port", "7860", \ + "--model-impl", "transformers", \ + "--worker-extension-cls", "trl.experimental.async_grpo.delta_engine.HFBucketWorkerExtension", \ + "--weight-transfer-config", "{\"backend\":\"hf_bucket\"}", \ + "--max-model-len", "8192", \ + "--gpu-memory-utilization", "0.85", \ + "--logprobs-mode", "processed_logprobs"] diff --git a/examples/scripts/async_grpo_buckets/vllm_space/README.md b/examples/scripts/async_grpo_buckets/vllm_space/README.md new file mode 100644 index 00000000000..3f86a463d51 --- /dev/null +++ b/examples/scripts/async_grpo_buckets/vllm_space/README.md @@ -0,0 +1,19 @@ +--- +title: vLLM Wordle Inference +emoji: 🎮 +colorFrom: blue +colorTo: green +sdk: docker +app_port: 7860 +hardware: l4 +--- + +vLLM inference server for **disaggregated** async GRPO training. + +Serves `Qwen/Qwen3-1.7B` and keeps in sync with a remote trainer via **bucket weight sync**: the trainer uploads +the changed bf16 weights as sparse patches to an HF Storage Bucket, and this Space applies them in place using the +`hf_bucket` weight-transfer backend (registered by the `HFBucketWorkerExtension`, served with `--model-impl +transformers` + `VLLM_USE_V2_MODEL_RUNNER=0` so the in-place sparse apply works). + +Used by `examples/scripts/async_grpo_buckets/async_grpo_buckets.py` in the TRL repo. See +`examples/scripts/async_grpo_buckets/README.md` for the end-to-end deploy + run guide. 
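The patch-then-apply flow this Space takes part in can be illustrated without vLLM. The sketch below is an illustration only, assuming bf16 values are obtained from float32 by round-to-nearest-even on the dropped 16 bits. `to_bf16`, `build_sparse_patch`, and `apply_sparse_patch` are hypothetical names; the real wire format, index encoding, and the `hf_bucket` backend's apply path are not shown here.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bf16 value (returned as a Python float)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 low bits that bf16 drops, then clear them.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def build_sparse_patch(old, new):
    """Map index -> new bf16 value, only for entries whose bf16 value changed."""
    patch = {}
    for i, (o, n) in enumerate(zip(old, new)):
        n16 = to_bf16(n)
        if to_bf16(o) != n16:
            patch[i] = n16
    return patch


def apply_sparse_patch(weights, patch):
    """Overwrite only the patched entries of `weights`, in place."""
    for i, value in patch.items():
        weights[i] = value


if __name__ == "__main__":
    old = [1.0, 0.5, -2.0, 0.25]
    new = [1.0 - 1e-6, 0.5 + 1e-7, -2.0, 0.26]  # only the last update survives bf16 rounding
    print(build_sparse_patch(old, new))
```

Updates smaller than half a bf16 step leave the stored value unchanged, so they never enter the patch. This is one plausible reason a small learning rate yields patches that touch only a small fraction of the weights.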
diff --git a/examples/scripts/async_grpo_buckets/wordle_space/Dockerfile b/examples/scripts/async_grpo_buckets/wordle_space/Dockerfile new file mode 100644 index 00000000000..bfeca9af1e7 --- /dev/null +++ b/examples/scripts/async_grpo_buckets/wordle_space/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.11-slim + +RUN apt-get update && apt-get install -y --no-install-recommends git git-lfs && rm -rf /var/lib/apt/lists/* +RUN git lfs install +ENV GIT_CLONE_PROTECTION_ACTIVE=false +RUN pip install --no-build-isolation "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle" \ + || (git clone https://huggingface.co/spaces/openenv/wordle /tmp/wordle && pip install /tmp/wordle && rm -rf /tmp/wordle) + +WORKDIR /app + +COPY app.py /app/app.py + +EXPOSE 7860 + +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"] diff --git a/examples/scripts/async_grpo_buckets/wordle_space/README.md b/examples/scripts/async_grpo_buckets/wordle_space/README.md new file mode 100644 index 00000000000..616368fd90b --- /dev/null +++ b/examples/scripts/async_grpo_buckets/wordle_space/README.md @@ -0,0 +1,11 @@ +--- +title: Wordle Environment (High Capacity) +emoji: 🟩 +colorFrom: green +colorTo: yellow +sdk: docker +app_port: 7860 +--- + +Wordle environment server with 256 concurrent session support for async GRPO training. +Used by `examples/scripts/async_grpo_buckets/async_grpo_buckets.py` in the TRL repo.
diff --git a/examples/scripts/async_grpo_buckets/wordle_space/app.py b/examples/scripts/async_grpo_buckets/wordle_space/app.py new file mode 100644 index 00000000000..9e0e251df8d --- /dev/null +++ b/examples/scripts/async_grpo_buckets/wordle_space/app.py @@ -0,0 +1,20 @@ +"""Wordle environment server with higher concurrent session capacity for async GRPO training.""" + +import os + +from openenv.core.env_server.http_server import ConcurrencyConfig, create_app +from textarena_env.models import TextArenaAction, TextArenaObservation +from textarena_env.server.environment import TextArenaEnvironment + + +# Mark TextArena as safe for concurrent sessions (each session gets its own game instance) +TextArenaEnvironment.SUPPORTS_CONCURRENT_SESSIONS = True + +max_sessions = int(os.getenv("MAX_CONCURRENT_SESSIONS", "256")) + +app = create_app( + TextArenaEnvironment, + TextArenaAction, + TextArenaObservation, + concurrency_config=ConcurrencyConfig(max_concurrent_envs=max_sessions), +) diff --git a/examples/scripts/async_grpo_delta.py b/examples/scripts/async_grpo_delta.py new file mode 100644 index 00000000000..4dd715caa73 --- /dev/null +++ b/examples/scripts/async_grpo_delta.py @@ -0,0 +1,96 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AsyncGRPO with sparse weight sync over an HF Storage Bucket (`weight_sync_backend="bucket"`). 
+ +Same sparse delta as the default NCCL path (only the changed bf16 weights, recovered by inverting the AdamW step), but +the patch is routed through an HF Storage Bucket instead of NCCL and applied in place on vLLM via PR #40096. This is useful +when the trainer and the vLLM server are not in the same NCCL world (e.g. cross-host). No full-model broadcast. + +Start the vLLM server with the `hf_bucket` backend + worker extension (registers the bucket engine) and the +`transformers` model impl (so vLLM's runtime param names match the trainer's HF names; every param is then +addressable by the in-place sparse apply, no fuse/unfuse remap needed): + +# VLLM_USE_V2_MODEL_RUNNER=0 is required: the in-place sparse apply (apply_sparse_weight_patches, +# vLLM #40096) exists only on the V1 model runner. Without it the server picks V2 and every sparse +# delta update fails (the dense anchors still work, so it silently degrades to anchor-only sync). +CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \ + --model-impl transformers \ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.HFBucketWorkerExtension \ + --weight-transfer-config '{"backend":"hf_bucket"}' \ + --max-model-len 2560 + +CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo_delta.py +""" + +import logging +import os + +from datasets import load_dataset + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer +from trl.rewards import accuracy_reward + + +logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logging.getLogger("trl").setLevel(logging.INFO) + + +def format_sample(sample): + return { + "prompt": [{"role": "user", "content": sample["question"]}], + "solution": sample["answer"].split("####")[-1].strip(), + } + + +def main() -> None: + dataset = load_dataset("openai/gsm8k", "main", split="train")
+ dataset = dataset.map(format_sample, remove_columns=dataset.column_names) + + config = AsyncGRPOConfig( + output_dir="./results/async_grpo_delta", + per_device_train_batch_size=1, + num_generations=8, + max_completion_length=512, + max_steps=60, + learning_rate=1e-5, + logging_steps=1, + bf16=True, + report_to="none", + project="async_grpo_delta", + log_completions=True, + # Qwen3 thinking traces blow past the completion cap on GSM8K (truncated -> no answer -> + # zero reward); disable thinking so completions are short and accuracy_reward gets signal. + chat_template_kwargs={"enable_thinking": False}, + weight_sync_mode="sparse", # default; only the changed bf16 weights are sent + weight_sync_backend="bucket", # route the patch through a bucket instead of NCCL + weight_sync_bucket_id="/async-grpo-delta-demo", # set to a bucket you own + weight_sync_anchor_interval=20, # full anchor every N syncs; sparse deltas in between + weight_sync_encoding="gap_delta", # raw | gap_delta | nvcomp_cascaded + ) + trainer = AsyncGRPOTrainer( + model="Qwen/Qwen3-1.7B", + args=config, + train_dataset=dataset, + reward_funcs=accuracy_reward, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 279ccce42c1..e7427e5a258 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,14 @@ vlm = [ math_verify = [ "math-verify>=0.5.2", ] +delta_weight_sync = [ + # Delta weight sync for AsyncGRPO (trl/experimental/async_grpo). Needs vLLM's sparse weight + # transfer (vllm-project/vllm#40096): merged to main after v0.22.0, not in a release yet — + # install vLLM from the nightly index until it ships (so no `vllm` pin here). Serve with + # `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`. 
+ "huggingface-hub>=1.17.0", # HF Storage Bucket API (create/batch/download_bucket_files) + "nvidia-nvcomp-cu12>=5.2.0", # optional: only for the nvcomp_cascaded index encoding (CUDA 12) +] openreward = [ "openreward>=0.1.109; python_version >= '3.11'", # openreward requires Python 3.11+ ] diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index d2b588a9033..d9d66663d33 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -76,6 +76,21 @@ class AsyncGRPOConfig(_BaseConfig): heartbeat_stale_after_s (`float`, *optional*, defaults to `300.0`): Seconds since the rollout worker's last heartbeat after which the trainer treats it as hung and aborts. + weight_sync_mode (`str`, *optional*, defaults to `"sparse"`): + How to sync the policy to vLLM. `"sparse"` sends only the changed bf16 weights (the changed set is + recovered by inverting the AdamW step from the resident optimizer moments — no snapshot kept); requires a + `torch.optim.AdamW` optimizer and a vLLM with sparse weight transfer. `"full"` broadcasts the entire policy + over NCCL every sync (use this when the optimizer is not AdamW). + weight_sync_backend (`str`, *optional*, defaults to `"nccl"`): + Transport for the sparse patches: `"nccl"` (broadcast in place over the NCCL group) or `"bucket"` (upload + to an HF Storage Bucket and apply from there). The `"full"` mode is always NCCL. + weight_sync_bucket_id (`str`, *optional*): + HF Storage Bucket for the patches/anchors (created if missing). Required when `weight_sync_backend="bucket"`. + weight_sync_anchor_interval (`int`, *optional*, defaults to `20`): + In `"sparse"` mode, send a full anchor every N syncs; sparse deltas in between (bounds drift from low-byte / + inversion misses). 
+ weight_sync_encoding (`str`, *optional*, defaults to `"gap_delta"`): + Index encoding for bucket patches: `"raw"`, `"gap_delta"`, or `"nvcomp_cascaded"` (needs nvcomp). > Parameters that control the logging @@ -195,6 +210,47 @@ class AsyncGRPOConfig(_BaseConfig): }, ) + # Parameters that control weight sync + weight_sync_mode: str = field( + default="sparse", + metadata={ + "help": "How to sync the policy to vLLM. `'sparse'` (default) sends only the changed bf16 weights (the " + "changed set is recovered by inverting the AdamW step from the resident optimizer moments); " + "requires a `torch.optim.AdamW` optimizer and a vLLM with sparse weight transfer, served with " + "`--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`. `'full'` broadcasts the entire policy over " + "NCCL every sync (use this when the optimizer is not AdamW).", + "choices": ["sparse", "full"], + }, + ) + weight_sync_backend: str = field( + default="nccl", + metadata={ + "help": "Transport for the sparse patches: `'nccl'` (default, broadcast in place over the NCCL group) or " + "`'bucket'` (upload to an HF Storage Bucket and apply from there). The `'full'` mode is always NCCL.", + "choices": ["nccl", "bucket"], + }, + ) + weight_sync_bucket_id: str | None = field( + default=None, + metadata={ + "help": "HF Storage Bucket for the patches/anchors (created if missing). Required when " + "`weight_sync_backend='bucket'`." + }, + ) + weight_sync_anchor_interval: int = field( + default=20, + metadata={ + "help": "In `'sparse'` mode, send a full anchor every N syncs; sparse deltas in between (bounds drift " + "from low-byte / inversion misses)." + }, + ) + weight_sync_encoding: str = field( + default="gap_delta", + metadata={ + "help": "Index encoding for bucket patches: 'raw', 'gap_delta', or 'nvcomp_cascaded' (needs nvcomp)." 
+ }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -211,6 +267,16 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + # Validate the weight-sync configuration. + if self.weight_sync_mode not in ("sparse", "full"): + raise ValueError(f"weight_sync_mode must be 'sparse' or 'full', got {self.weight_sync_mode!r}") + if self.weight_sync_backend not in ("nccl", "bucket"): + raise ValueError(f"weight_sync_backend must be 'nccl' or 'bucket', got {self.weight_sync_backend!r}") + if self.weight_sync_mode == "full" and self.weight_sync_backend != "nccl": + raise ValueError("weight_sync_mode='full' transfers over NCCL; set weight_sync_backend='nccl'.") + if self.weight_sync_backend == "bucket" and self.weight_sync_bucket_id is None: + raise ValueError("weight_sync_backend='bucket' requires weight_sync_bucket_id to be set.") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. 
diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 78a008a7ff8..abada9a7e5f 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,7 +35,8 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker -from .weight_transfer import WeightTransferClient +from .weight_diff import AdamWInversionChangeDetector +from .weight_transfer import make_weight_transfer logger = get_logger(__name__) @@ -109,8 +110,8 @@ def on_train_begin(self, _args, _state, _control, **_kwargs): if self._fired: return self._fired = True - if self._trainer.accelerator.is_main_process and self._trainer.weight_transfer is not None: - self._trainer.weight_transfer.init_weight_transfer() + if self._trainer.weight_transfer is not None: + self._trainer.weight_transfer.init(self._trainer.accelerator) self._trainer._sync_weight() @@ -413,38 +414,42 @@ def __init__( self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} self._train_tokens_start_time = None self.model_version = 0 - # Create worker and queue on rank 0 + self._change_detector: AdamWInversionChangeDetector | None = None # sparse sync only; created in compute_loss + # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training, and identical + # across ranks (DTensor.shape returns the global shape without triggering any all-gather). + weight_names, weight_dtype_names, weight_shapes = [], [], [] + for name, param in model.named_parameters(): + name = name.removeprefix("module.") # DDP/FSDP1 wrapping, avoids vllm module not exist error + weight_names.append(name) + weight_dtype_names.append(str(param.dtype).split(".")[-1]) + weight_shapes.append(list(param.shape)) + + # The weight transport lives on every rank: rank 0 drives it, the others walk the iterator during sync so the + # FSDP2 full_tensor() collectives line up. 
An injected stub worker (tests, no real vLLM) leaves it unset. + self.weight_transfer = make_weight_transfer( + self.args.weight_sync_backend, + vllm_server_url=self.args.vllm_server_base_url, + server_timeout=self.args.vllm_server_timeout, + weight_update_info={ + "names": weight_names, + "dtype_names": weight_dtype_names, + "shapes": weight_shapes, + }, + bucket_id=self.args.weight_sync_bucket_id, + encoding=self.args.weight_sync_encoding, + ) + + # Create the rollout worker and its queue on rank 0. if self.accelerator.is_main_process: if self.train_dataset is None: raise ValueError("train_dataset is required for AsyncGRPOTrainer") if rollout_worker is not None: - # Use the injected worker (e.g. a stub in tests). The queue is owned by the worker. - # Weight transfer is also expected to be wired by the test fixture (or left as None - # if the stub doesn't sync to a real vLLM). + # Injected worker (e.g. a stub in tests). The queue is owned by the worker; with no real vLLM to + # sync to, drop the transport so weight sync is a no-op. self.rollout_worker = rollout_worker self.weight_transfer = None else: - # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training. - # DTensor.shape returns the global shape without triggering any all-gather. 
- weight_names, weight_dtype_names, weight_shapes = [], [], [] - for name, param in model.named_parameters(): - # DDP/FSDP1 wrapping, avoids vllm module not exist error - name = name.removeprefix("module.") - weight_names.append(name) - weight_dtype_names.append(str(param.dtype).split(".")[-1]) - weight_shapes.append(list(param.shape)) - self.weight_transfer = WeightTransferClient( - vllm_server_url=self.args.vllm_server_base_url, - server_timeout=self.args.vllm_server_timeout, - weight_update_info={ - "names": weight_names, - "dtype_names": weight_dtype_names, - "shapes": weight_shapes, - "packed": True, - "is_checkpoint_format": True, - }, - ) self.rollout_worker = AsyncRolloutWorker( model_name=model_name, dataset=train_dataset, @@ -469,7 +474,6 @@ def __init__( else: self.rollout_queue = None self.rollout_worker = None - self.weight_transfer = None # Add callbacks. Registration order matters: weight sync first, then worker start. self.add_callback(_InitialWeightSyncCallback(self)) @@ -517,6 +521,10 @@ def _set_signature_columns_if_needed(self): ] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Create the change detector once the prepared optimizer exists (it validates AdamW up front, + # so a non-AdamW optimizer fails fast). No-op unless weight_sync_mode='sparse'. + self._maybe_init_change_detector() + input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] @@ -656,49 +664,87 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: super().log(logs, start_time) self._metrics[mode].clear() - def _streaming_iter(self): - # Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across - # FSDP ranks, then frees it once the generator advances — avoiding materializing the full model in memory. + def _weight_iter(self, sparse: bool): + """Yield ``(name, full_tensor, mask)`` for weight sync. 
+ + For a sparse patch (``sparse=True``) only changed params are yielded, with their element-level change mask from + the AdamW-inversion detector; for a full transfer / anchor every param is yielded with ``mask=None``. Iterate + params one at a time: for FSDP2 (DTensor), ``full_tensor()`` all-gathers just this param across FSDP ranks then + frees it once the generator advances, which avoids materializing the full model. All ranks must walk this + identically so the ``full_tensor()`` collectives line up. + + Under FSDP2 the change mask is itself a DTensor sharded exactly like the param (it is reconstructed from the + per-shard AdamW moments), so it is first gathered to the global shape with ``full_tensor()``. The skip decision + then calls ``mask.any()`` on that plain tensor, so it is identical on every rank and keeps the param's + ``full_tensor()`` collective that follows aligned. + """ device = self.accelerator.device + masks = self._change_detector._validated_masks if (sparse and self._change_detector is not None) else {} for name, param in self.model.named_parameters(): name = name.removeprefix("module.") # DDP/FSDP1 wrapping + if sparse: + mask = masks.get(name) + if mask is None: + continue # param never stepped -> not in this delta (consistent across ranks) + # Gather the per-shard mask to the global shape first, so `.any()` runs on a plain tensor (DTensor + # has no sharding strategy for the reduction) and lines up with `full` below. + if isinstance(mask, DTensor): + mask = mask.full_tensor() + if not bool(mask.any()): + continue # unchanged param -> not in this delta + else: + mask = None full = param.full_tensor() if isinstance(param, DTensor) else param.detach() if full.device != device: full = full.to(device) - yield name, full + yield name, full, mask + + def _maybe_init_change_detector(self): + """Create the AdamW-inversion change detector once the (prepared) optimizer exists. 
Validates that the + optimizer is AdamW (raises otherwise, pointing at weight_sync_mode='full'). No-op unless sparse sync.""" + if ( + self.args.weight_sync_mode == "sparse" + and self._change_detector is None + and getattr(self, "optimizer", None) + ): + # Unwrap AcceleratedOptimizer to the native torch optimizer. + raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) + self._change_detector = AdamWInversionChangeDetector(self.model, raw_optimizer) def _sync_weight(self): + # No real vLLM to sync to (injected stub worker in tests). + if self.weight_transfer is None: + return t0 = time.time() - logger.info("Weight sync: pausing vLLM...") - if self.accelerator.is_main_process and self.weight_transfer: - self.weight_transfer.pause() - t_pause = time.time() - logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...") - - self.accelerator.wait_for_everyone() - t_barrier = time.time() - - logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") - if self.accelerator.is_main_process and self.weight_transfer: - self.weight_transfer.send_weights(self._streaming_iter()) - else: - # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. - for _ in self._streaming_iter(): - pass - t_transfer = time.time() - - self.accelerator.wait_for_everyone() + next_version = self.model_version + 1 + # In sparse mode, send a full anchor on the first sync and every Nth after, to bound drift from inversion + # misses; in full mode every sync is a full transfer. Computed identically on every rank (deterministic) so + # all ranks pick the same path and their full_tensor() collectives line up. + is_anchor = ( + self.args.weight_sync_mode == "full" + or next_version == 1 + or next_version % self.args.weight_sync_anchor_interval == 0 + ) + sparse = not is_anchor + + # Reconstruct the change mask from the AdamW state (no snapshot) for a sparse patch. 
+ if sparse and self._change_detector is not None: + self._change_detector.compute_masks() + + # The transport owns its phasing (NCCL: single-phase broadcast; bucket: upload then apply). + self.weight_transfer.sync( + iter_fn=self._weight_iter, + sparse=sparse, + is_anchor=is_anchor, + version=next_version, + accelerator=self.accelerator, + ) - logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)") - if self.accelerator.is_main_process: - if self.weight_transfer: - self.weight_transfer.resume() - self.model_version += 1 - if self.rollout_worker: - self.rollout_worker.update_model_version(self.model_version) - weight_sync_time_s = time.time() - t0 - self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") + # Bump on every rank to keep the anchor cadence in lockstep; only rank 0 owns the rollout worker's version. + self.model_version += 1 + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.update_model_version(self.model_version) + self._metrics["train"]["weight_sync_time_s"].append(time.time() - t0) def _inner_training_loop(self, *args, **kwargs): try: diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index dbf426695ba..250600e05b9 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -13,6 +13,7 @@ # limitations under the License. 
import asyncio +import copy import inspect import multiprocessing as mp import os @@ -75,6 +76,7 @@ class RolloutGroup: tool_mask: list[list[int]] tool_call_counts: list[int] tool_failure_counts: list[int] + environments: list[object] model_version: int queued_at: float = 0.0 @@ -338,6 +340,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: tool_mask=[], tool_call_counts=[], tool_failure_counts=[], + environments=[], model_version=self.model_version, ) pending_completed[group_id] = 0 @@ -382,6 +385,9 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: group.tool_mask.append(tool_mask) group.tool_call_counts.append(tool_call_count) group.tool_failure_counts.append(tool_failure_count) + if self.environments is not None: + # NOTE(@aminediro) Snapshot the env (with its accumulated reward) for this completion; the slot's env + group.environments.append(copy.copy(self.environments[slot])) self._total_completion_tokens += sum(tool_mask) pending_completed[group_id] += 1 @@ -598,6 +604,8 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]: completion_ids=group.completions_ids, **group.reward_kwargs, ) + if group.environments: + kwargs["environments"] = group.environments all_rewards = await asyncio.gather( *[ reward_func(**kwargs) diff --git a/trl/experimental/async_grpo/delta_codec.py b/trl/experimental/async_grpo/delta_codec.py new file mode 100644 index 00000000000..095846eec2b --- /dev/null +++ b/trl/experimental/async_grpo/delta_codec.py @@ -0,0 +1,185 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU-resident sparse delta extraction + index encoding. + +The change mask, ``nonzero``, and value gather all run on the device, so only the sparse payload (and, for the disk +transport, its compressed index form) crosses PCIe. This replaces the dense ``tensor.to(bf16).cpu()`` + CPU +``nonzero``/gather in the v1 upload path, which copied the full dense tensor (~100%) to host just to keep ~1-3% of it. + +Index encodings (lossless, shrink only the index half; values are sent raw): + +- ``raw`` : int32 absolute flat positions (4 B/elem), which is what NCCL / vLLM #40096 expect. +- ``gap_delta`` : ``idx[k] - idx[k-1] - 1`` packed to uint16 (2 B), uint32 fallback per param. +- ``nvcomp_cascaded`` : nvCOMP "Cascaded" (delta + bit-pack [+ RLE]) over the int32 indices, on GPU. 
+""" + +from __future__ import annotations + +from enum import Enum + +import numpy as np +import torch + + +try: + from nvidia import nvcomp # pip: nvidia-nvcomp + + _NVCOMP_OK = True +except Exception: # pragma: no cover - environment dependent + nvcomp = None + _NVCOMP_OK = False + + +class Encoding(str, Enum): + """Index-encoding scheme for a sparse delta patch (values are always stored raw).""" + + RAW = "raw" # int32 absolute positions (4 B/elem) + GAP_DELTA = "gap_delta" # uint16 gaps, uint32 fallback per param (~2 B/elem) + NVCOMP_CASCADED = "nvcomp_cascaded" # nvCOMP Cascaded delta+bitpack on the GPU (~1.3 B/elem) + + +class UpdateKind(str, Enum): + """vLLM weight-update format (the ``update_kind`` field in the update_info sent to vLLM).""" + + DENSE = "dense" + SPARSE_FLAT = "sparse_flat" + + +def extract_sparse( + param: torch.Tensor, + mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pull changed ``(indices, values)`` for one param, entirely on ``param``'s device. + + Args: + param (`torch.Tensor`): + Full dense parameter (kept on the GPU; never copied to host here). + mask (`torch.Tensor`): + Boolean change mask, same numel as ``param``, same device. + + Returns: + `tuple` of: + - indices (`torch.Tensor`): 1D `int32` flat positions (ascending), on device. + - values (`torch.Tensor`): 1D values at those positions, param dtype, on device. + """ + flat = param.detach().reshape(-1) + idx = mask.reshape(-1).nonzero(as_tuple=True)[0] # device; one sync (dynamic size) + vals = flat.index_select(0, idx) + return idx.to(torch.int32), vals + + +def extract_sparse_batched( + items: list[tuple[str, torch.Tensor, torch.Tensor]], + max_chunk_numel: int = 256_000_000, +) -> list[tuple[str, torch.Tensor, torch.Tensor]]: + """Batched [`extract_sparse`] over many params, chunked to bound the ``nonzero`` buffer. 
+ + Args: + items (`list[tuple[str, torch.Tensor, torch.Tensor]]`): + ``(name, tensor, mask)`` triples; ``mask`` is the per-param boolean change mask. + max_chunk_numel (`int`, *optional*, defaults to `256_000_000`): + Cap on the total element count concatenated per ``nonzero``. + + Returns: + `list[tuple[str, torch.Tensor, torch.Tensor]]`: ``(name, int32 local indices, values)`` per input param, in + input order. + """ + out: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + chunk: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + chunk_numel = 0 + for item in items: + n = item[1].numel() + if chunk and chunk_numel + n > max_chunk_numel: + out.extend(_extract_sparse_chunk(chunk)) + chunk, chunk_numel = [], 0 + chunk.append(item) + chunk_numel += n + if chunk: + out.extend(_extract_sparse_chunk(chunk)) + return out + + +def _extract_sparse_chunk( + items: list[tuple[str, torch.Tensor, torch.Tensor]], +) -> list[tuple[str, torch.Tensor, torch.Tensor]]: + """One batched extraction: a single ``nonzero`` over the concatenated masks, split back per param.""" + device = items[0][1].device + flats = [tensor.detach().reshape(-1) for _, tensor, _ in items] + sizes = torch.tensor([f.numel() for f in flats], device=device) + offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, 0)]) + global_idx = torch.cat([mask.reshape(-1) for _, _, mask in items]).nonzero(as_tuple=True)[0] + bounds = torch.searchsorted(global_idx, offsets[1:]).tolist() + + out = [] + prev = 0 + for i, (name, _, _) in enumerate(items): + g = global_idx[prev : bounds[i]] - offsets[i] # local positions within this param + out.append((name, g.to(torch.int32), flats[i].index_select(0, g))) + prev = bounds[i] + return out + + +def gap_delta_encode(idx: torch.Tensor) -> torch.Tensor: + """Gap-encode sorted positions: ``delta[k] = idx[k] - idx[k-1] - 1`` (idx[-1] := -1). + + Returns the gaps as ``uint16`` if the max gap fits, else ``uint32`` — so one outlier never bumps the whole param to + 4 B. 
The dtype *is* the width (no separate width needed); the receiver inverts with [`gap_delta_decode`]. + + Args: + idx (`torch.Tensor`): + 1D ascending `int32` positions (as returned by [`extract_sparse`]). + + Returns: + `torch.Tensor`: 1D gaps, dtype `uint16` or `uint32`, same device. + """ + if idx.numel() == 0: + return idx.to(torch.uint16) + idx64 = idx.long() + prev = torch.cat([idx64.new_full((1,), -1), idx64[:-1]]) + deltas = idx64 - prev - 1 + return deltas.to(torch.uint16 if int(deltas.max()) <= 0xFFFF else torch.uint32) + + +def gap_delta_decode(deltas: torch.Tensor) -> torch.Tensor: + """Invert [`gap_delta_encode`] → 1D ascending `int32` positions on ``deltas``' device.""" + if deltas.numel() == 0: + return deltas.to(torch.int32) + return (torch.cumsum(deltas.long() + 1, dim=0) - 1).to(torch.int32) + + +def nvcomp_available() -> bool: + return _NVCOMP_OK + + +def nvcomp_encode(idx: torch.Tensor) -> torch.Tensor: + """Compress absolute int32 indices with nvCOMP Cascaded → ``uint8`` byte tensor (CPU). + + Cascaded does the delta + bit-pack internally, so callers pass raw int32 positions (no gap-encoding needed). The + self-describing bitstream carries dtype + length, so [`nvcomp_decode`] needs nothing else. Cascaded is a GPU codec, + so indices are moved to CUDA. 
+ """ + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + comp = nvcomp.Codec(algorithm="Cascaded").encode(nvcomp.as_array(idx.to("cuda", torch.int32).contiguous())) + return torch.from_numpy(np.asarray(comp.cpu()).view(np.uint8).copy()) + + +def nvcomp_decode(raw: torch.Tensor) -> torch.Tensor: + """Inverse of [`nvcomp_encode`]: ``uint8`` bytes → 1D ``int32`` indices (CPU).""" + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + dec = nvcomp.Codec(algorithm="Cascaded").decode(nvcomp.as_array(raw.to("cuda").contiguous())) + return torch.from_numpy(np.asarray(dec.cpu()).view(np.int32).copy()) # reinterpret bytes → int32 diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py new file mode 100644 index 00000000000..b72fa32a6aa --- /dev/null +++ b/trl/experimental/async_grpo/delta_engine.py @@ -0,0 +1,301 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import logging +import os +import tempfile +import time +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +from huggingface_hub import batch_bucket_files, download_bucket_files +from safetensors import safe_open +from safetensors.torch import save_file +from vllm.distributed.weight_transfer.base import ( + SparseWeightPatch, + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +from .delta_codec import ( + Encoding, + UpdateKind, + extract_sparse_batched, + gap_delta_decode, + gap_delta_encode, + nvcomp_decode, + nvcomp_encode, +) +from .weight_diff import PatchMetadata + + +try: + from vllm.logger import init_logger + + logger = init_logger(f"vllm.{__name__}") +except Exception: # pragma: no cover - vLLM always present where this module is used + logger = logging.getLogger(__name__) + + +@contextmanager +def _fetch(update_info): + """Download a patch from the bucket to a temp file and yield an open safetensors handle.""" + # TODO(@aminediro): writes only 1 safetensors, for very large model, this needs to be multiple files + with tempfile.TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/weights.safetensors" + download_bucket_files(update_info.repo_id, files=[(update_info.filename, path)]) + with safe_open(path, framework="pt", device="cpu") as f: + yield f + + +def _encode_idx(idx: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Encode absolute int32 indices for storage. 
Returns a CPU tensor (I32, U16/U32, or U8 bytes).""" + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return idx.to(torch.int32).cpu().contiguous() + if encoding is Encoding.GAP_DELTA: + return gap_delta_encode(idx).cpu().contiguous() # native uint16/uint32 (dtype = width) + return nvcomp_encode(idx) # Encoding.NVCOMP_CASCADED -> uint8 CPU bytes + + +def _decode_idx(raw: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Inverse of [`_encode_idx`] → 1D int32 absolute indices. + + Self-describing: ``raw.dtype`` carries the gap-delta width, so no element count is needed. + """ + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return raw.to(torch.int32) + if encoding is Encoding.GAP_DELTA: + return gap_delta_decode(raw) + return nvcomp_decode(raw) + + +@dataclass +class HFBucketWeightTransferInitInfo(WeightTransferInitInfo): + pass + + +@dataclass +class HFBucketWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Per-sync info sent via ``/update_weights`` — just bucket coordinates + kind. + + Names and per-param nnz are read from the downloaded file, so ``num_updates_list`` is not required here (we + override the base validation that would otherwise demand it for sparse). 
+ """ + + repo_id: str = "" # bucket_id + filename: str = "" + + def __post_init__(self) -> None: + if self.update_kind not in (UpdateKind.DENSE, UpdateKind.SPARSE_FLAT): + raise ValueError(f"Unsupported update_kind: {self.update_kind}") + + +class HFBucketWeightTransferEngine( + WeightTransferEngine[HFBucketWeightTransferInitInfo, HFBucketWeightTransferUpdateInfo] +): + """Weight transfer engine using an HF Storage Bucket as the data plane.""" + + init_info_cls = HFBucketWeightTransferInitInfo + update_info_cls = HFBucketWeightTransferUpdateInfo + + def init_transfer_engine(self, init_info: HFBucketWeightTransferInitInfo) -> None: + pass + + def receive_weights( + self, + update_info: HFBucketWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Anchor path: download full safetensors from the bucket and load them directly.""" + t0 = time.time() + with _fetch(update_info) as f: + t_dl = time.time() + for name in f.keys(): + load_weights([(name, f.get_tensor(name))]) + meta = PatchMetadata.from_metadata_dict(f.metadata()) + t_apply = time.time() + logger.info( + "Applied anchor (step %d, %d params) | download %.2fs load %.2fs", + meta.model_version, + meta.num_changed_params, + t_dl - t0, + t_apply - t_dl, + ) + + def receive_sparse_weights( + self, + update_info: HFBucketWeightTransferUpdateInfo, + apply_patches: Callable[[list[SparseWeightPatch]], None], + ) -> None: + t0 = time.time() + patches = [] + with _fetch(update_info) as f: + t_dl = time.time() + names, idxs, vals = [], [], [] + meta = PatchMetadata.from_metadata_dict(f.metadata()) + for name, idx, values in iter_sparse_patches(f): + names.append(name) + idxs.append(idx) + vals.append(values) + if names: + device = torch.accelerator.current_device_index() + sizes = [i.numel() for i in idxs] + all_idx = torch.cat(idxs).to(device) + all_val = torch.cat(vals).to(device) + off = 0 + for name, n in zip(names, sizes, strict=False): + patches.append( + 
SparseWeightPatch(name=name, indices=all_idx[off : off + n], values=all_val[off : off + n]) + ) + off += n + if patches: + apply_patches(patches) + + t_apply = time.time() + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f) | download %.2fs decode+apply %.2fs", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + t_dl - t0, + t_apply - t_dl, + ) + + def shutdown(self) -> None: + pass + + @staticmethod + def trainer_send_weights(iterator, trainer_args) -> None: + raise NotImplementedError("Use WeightTransferClient.upload_patch / apply_patch instead") + + @staticmethod + def upload( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + bucket_id: str, + filename: str, + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, + ) -> PatchMetadata | None: + """Encode params as a safetensors patch and push to the bucket. + + Returns the :class:`PatchMetadata` (also written to the safetensors header), or ``None`` if the iterator was + empty. + """ + tensors, meta = encode_patch(iterator, model_version=model_version, encoding=encoding) + if tensors is None: + return None + # Write to a temp file and upload the path: hf-xet's in-memory `upload_bytes` panics on + # multi-GB buffers (e.g. a full-model anchor); the file path uses the large-file code path. 
+ with tempfile.TemporaryDirectory() as tmpdir: + local_path = f"{tmpdir}/patch.safetensors" + save_file(tensors, local_path, metadata=meta.to_metadata_dict()) + size_mb = os.path.getsize(local_path) / 1e6 + batch_bucket_files(bucket_id, add=[(local_path, filename)]) + logger.info( + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, enc=%s, sparsity=%.4f)", + bucket_id, + filename, + size_mb, + meta.num_changed_params, + meta.sparse, + meta.encoding.value, + meta.sparsity, + ) + return meta + + +def encode_patch( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, +) -> tuple[dict[str, torch.Tensor] | None, PatchMetadata | None]: + """Build the safetensors tensor dict + metadata for a patch (no I/O). + + Each item is ``(name, tensor, mask)``: + + - ``mask is None``: full tensor stored as ``name`` (anchor). + - ``mask`` provided: GPU sparse-extract; store encoded indices as ``{name}.idx`` and values as ``{name}.val``. + ``encoding``: ``"raw"`` (int32), ``"gap_delta"`` (uint16/uint32 gaps) or ``"nvcomp_cascaded"`` (uint8 bytes). 
+ """ + encoding = Encoding(encoding) + tensors: dict[str, torch.Tensor] = {} + delta_items: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + n_params = 0 + total_changed = 0 + total_elements = 0 + + for name, tensor, mask in iterator: + n_params += 1 + total_elements += tensor.numel() + if mask is None: # anchor: store the full tensor + tensors[name] = tensor.detach().to(torch.bfloat16).cpu().contiguous().clone() + total_changed += tensor.numel() + else: + delta_items.append((name, tensor, mask)) + + for name, idx, vals in extract_sparse_batched(delta_items): + total_changed += idx.numel() + tensors[f"{name}.val"] = vals.to(torch.bfloat16).cpu().contiguous() + tensors[f"{name}.idx"] = _encode_idx(idx, encoding) + + sparse = bool(delta_items) + if not tensors: + return None, None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=n_params, + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + encoding=encoding, + ) + return tensors, meta + + +def iter_sparse_patches(f) -> Iterator[tuple[str, torch.Tensor, torch.Tensor]]: + """Yield ``(name, int32 indices, values)`` from an open sparse safetensors handle. + + Flat, self-describing format: param names are recovered from the ``{name}.val`` tensor keys and the index encoding + is a single global header field (the gap-delta width is carried by the index tensor's own dtype). ``f`` is a + ``safetensors.safe_open`` handle. + """ + encoding = PatchMetadata.from_metadata_dict(f.metadata()).encoding + names = sorted({k[: -len(".val")] for k in f.keys() if k.endswith(".val")}) + for name in names: + idx = _decode_idx(f.get_tensor(f"{name}.idx"), encoding) + yield name, idx, f.get_tensor(f"{name}.val") + + +class HFBucketWorkerExtension: + """vLLM worker-extension hook (pass via ``--worker-extension-cls``). 
+ + Required: ``--worker-extension-cls`` makes the vLLM *worker* process import this module, which runs the + ``register_engine`` call below so the ``"hf_bucket"`` backend exists in the worker (the factory registry is + per-process)""" + + pass + + +if "hf_bucket" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("hf_bucket", HFBucketWeightTransferEngine) diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..55c65f9ccea --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,183 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization utilities. + +- ``AdamWInversionChangeDetector``: recovers which bf16 elements changed across an optimizer step by *inverting* the + AdamW update from the resident moments — no pre-step snapshot kept. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both anchor (full) and delta (sparse) weight + files. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass +from enum import Enum + +import torch +from torch.distributed._tensor import DTensor + +from .delta_codec import Encoding + + +logger = logging.getLogger(__name__) + + +def _adamw_reconstruct_pre_step(param: torch.Tensor, state: dict, group: dict) -> torch.Tensor: + """Reconstruct a parameter's pre-step value from its post-step value and the resident AdamW moments. + + Decoupled AdamW (the ``torch.optim.AdamW`` update): + + ``theta_t = theta_{t-1} * (1 - lr*wd) - (lr/bc1) * m_t / (sqrt(v_t)/sqrt(bc2) + eps)`` + + with bias corrections ``bc1 = 1 - beta1**t`` and ``bc2 = 1 - beta2**t``. The moments ``m_t`` / ``v_t`` resident in + ``optimizer.state`` after the step are exactly the ones the step used, so inverting it recovers ``theta_{t-1}``: + + ``theta_{t-1} = (theta_t + (lr/bc1) * m_t / (sqrt(v_t)/sqrt(bc2) + eps)) / (1 - lr*wd)`` + + Returns an ``fp32`` reconstruction on ``param``'s device. The result is exact up to floating-point error, so a + bf16-rounded comparison may flip rare elements near a rounding boundary (bounded by periodic anchors). + """ + beta1, beta2 = group["betas"] + lr, eps, weight_decay = group["lr"], group["eps"], group["weight_decay"] + step = state["step"] + t = step.item() if torch.is_tensor(step) else step + bias_correction1 = 1 - beta1**t + bias_correction2 = 1 - beta2**t + exp_avg = state["exp_avg"].float() + exp_avg_sq = state["exp_avg_sq"].float() + denom = exp_avg_sq.sqrt() / (bias_correction2**0.5) + eps + update = (lr / bias_correction1) * exp_avg / denom + return (param.detach().float() + update) / (1 - lr * weight_decay) + + +class AdamWInversionChangeDetector: + """Detects changed bf16 weights by *inverting* the AdamW step from the resident moments. 
+ + The AdamW update is invertible, so instead of keeping a pre-step snapshot we reconstruct each param's pre-step + value on the fly from the ``exp_avg`` / ``exp_avg_sq`` moments already living in ``optimizer.state`` (see + [`_adamw_reconstruct_pre_step`]), round it to bf16, and compare it against the live bf16 weights. Persistent extra + storage is **zero**: the reconstruction is transient and thrown away. + + ``_validated_masks[name]`` is a boolean tensor, True for each element detected as changed in the last step; + populated by [`compute_masks`]. + + Args: + model ([`~torch.nn.Module`]): + Model whose parameters are tracked. + optimizer ([`~torch.optim.Optimizer`]): + Optimizer driving the step. Must be a [`torch.optim.AdamW`] (unwrap Accelerate first). + """ + + def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): + if not isinstance(optimizer, torch.optim.AdamW): + raise TypeError( + f"Sparse weight sync reconstructs the pre-step weights from the resident AdamW moments, so it " + f"requires a `torch.optim.AdamW` optimizer, but got `{type(optimizer).__name__}`. Set " + f"`weight_sync_mode='full'` to broadcast the full policy over NCCL instead."
+ ) + self.optimizer = optimizer + self._validated_masks: dict[str, torch.Tensor] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "AdamWInversionChangeDetector: matched %d/%d optimizer params", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + ) + + def compute_masks(self) -> dict[str, torch.Tensor]: + """Reconstruct the pre-step weights from the current AdamW state and diff against the live weights. + + Returns ``_validated_masks`` (``{name: bool tensor}``, True where the bf16 weight changed in the last step). + The reconstruction is rounded to bf16 before the comparison: an unchanged bf16 weight differs from its fp32 + reconstruction by a sub-ULP residual, so an fp32 comparison would flag ~every element — the bf16 round makes + the mask track the actual bf16 changes (sparse). Params that have never stepped (no optimizer state) are + omitted. + + Under FSDP2 ``p`` and its moments are DTensors. The reconstruction + comparison run on the **local shards** + (the AdamW step is elementwise, so it inverts shard-locally — and ``aten.ne`` has no DTensor sharding + strategy), then the mask is re-wrapped as a DTensor with the param's placement so the trainer can + ``full_tensor()`` it. Runs entirely on each param's device. + """ + self._validated_masks.clear() + for group in self.optimizer.param_groups: + for p in group["params"]: + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + state = self.optimizer.state.get(p) + if not state: # never stepped (e.g. 
frozen / no grad) -> nothing changed + continue + is_dtensor = isinstance(p, DTensor) + p_local = p.to_local() if is_dtensor else p + local_state = { + "step": state["step"], + "exp_avg": state["exp_avg"].to_local() if is_dtensor else state["exp_avg"], + "exp_avg_sq": state["exp_avg_sq"].to_local() if is_dtensor else state["exp_avg_sq"], + } + theta_old = _adamw_reconstruct_pre_step(p_local, local_state, group) + mask = p_local.to(torch.bfloat16) != theta_old.to(torch.bfloat16) + if is_dtensor: + mask = DTensor.from_local(mask, p.device_mesh, p.placements) + self._validated_masks[name] = mask + return self._validated_masks + + +@dataclass +class PatchMetadata: + format: str = "sparse_weight_patch" + version: str = "1" + sparse: bool = False + model_version: int = 0 + num_changed_params: int = 0 + total_changed_elements: int = 0 + total_elements: int = 0 + sparsity: float = 0.0 + encoding: Encoding = Encoding.GAP_DELTA # index encoding (delta files only) + + def to_metadata_dict(self) -> dict[str, str]: + return {k: (v.value if isinstance(v, Enum) else str(v)) for k, v in asdict(self).items()} + + @classmethod + def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + for k, v in d.items(): + if k not in field_types: + continue + ft = field_types[k] + if ft == "int": + kwargs[k] = int(v) + elif ft == "float": + kwargs[k] = float(v) + elif ft == "bool": + kwargs[k] = v == "True" + elif ft == "Encoding": + kwargs[k] = Encoding(v) + else: + kwargs[k] = v + return cls(**kwargs) diff --git a/trl/experimental/async_grpo/weight_transfer.py b/trl/experimental/async_grpo/weight_transfer.py index d0af83c2d2a..02ce798a307 100644 --- a/trl/experimental/async_grpo/weight_transfer.py +++ b/trl/experimental/async_grpo/weight_transfer.py @@ -12,24 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import threading import time +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator +from typing import TypedDict import requests +import torch from accelerate.logging import get_logger +from huggingface_hub import create_bucket from trl.import_utils import is_vllm_available +from .delta_codec import UpdateKind, extract_sparse_batched +from .delta_engine import HFBucketWeightTransferEngine + if is_vllm_available(min_version="0.17.1"): + from vllm.distributed.weight_transfer.base import SparseWeightPatch from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine from vllm.utils.network_utils import get_ip, get_open_port logger = get_logger(__name__) +# A weight iterator yields ``(name, full_tensor, mask)`` per parameter; ``mask`` is None for a full transfer. +WeightIterFn = Callable[[bool], Iterator[tuple[str, torch.Tensor, torch.Tensor | None]]] + + +class PendingPatch(TypedDict): + """A bucket patch uploaded in [`BucketWeightTransfer`] phase 1, awaiting the phase-2 apply.""" + + repo_id: str # bucket id + filename: str # path within the bucket + update_kind: UpdateKind # DENSE (anchor) or SPARSE_FLAT (delta) + + +class WeightTransfer(ABC): + """Base for the trainer-side weight-sync transports. Holds the vLLM HTTP plumbing shared by every transport. + + Args: + vllm_server_url (`str`): + Base URL of the vLLM server. + weight_update_info (`dict`): + Full-model metadata with keys ``names`` / ``dtype_names`` / ``shapes`` (one entry per parameter, in + ``model.named_parameters()`` order). Used to build the update_info sent to vLLM. + server_timeout (`float`, *optional*, defaults to `240.0`): + Seconds to wait for the vLLM server to become ready. + init_weight_transfer_timeout (`int`, *optional*, defaults to `1800`): + Timeout for the one-off ``/init_weight_transfer_engine`` call. 
+ """ -class WeightTransferClient: def __init__( self, vllm_server_url: str, @@ -39,13 +75,27 @@ def __init__( ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( - "vLLM >= 0.17.1 is required to use WeightTransferClient. Install it with: pip install 'vllm>=0.17.1'" + "vLLM >= 0.17.1 is required to use the weight-sync transports. Install it with: " + "pip install 'vllm>=0.17.1'" ) self.vllm_server_url = vllm_server_url.rstrip("/") self.server_timeout = server_timeout self.init_weight_transfer_timeout = init_weight_transfer_timeout - self._weight_update_info = weight_update_info - self.model_update_group = None + self._names = weight_update_info["names"] + self._dtype_names = weight_update_info["dtype_names"] + self._shapes = weight_update_info["shapes"] + + @abstractmethod + def init(self, accelerator) -> None: + """Set up the transport (rank 0 only). Called once before the first sync.""" + + @abstractmethod + def sync(self, *, iter_fn: WeightIterFn, sparse: bool, is_anchor: bool, version: int, accelerator) -> None: + """Push the current policy to vLLM. Runs on every rank; rank 0 drives the transport, the others only walk + ``iter_fn`` so the FSDP2 collectives line up.""" + + def destroy(self) -> None: # noqa: B027 - intentional no-op default; NCCL overrides, bucket needs none + """Tear down the transport (rank 0 only). Default: nothing to do.""" def _wait_for_server_ready_sync(self, timeout_s: float | None = None, poll_interval_s: float = 2.0) -> None: timeout_s = timeout_s if timeout_s is not None else self.server_timeout @@ -70,10 +120,44 @@ def _wait_for_server_ready_sync(self, timeout_s: float | None = None, poll_inter logger.info(f"Still waiting for vLLM server... 
({elapsed:.0f}s)") time.sleep(poll_interval_s) - def init_weight_transfer(self) -> None: + def _post_vllm(self, path: str, json_body: dict, retries: int = 1, timeout: int = 300) -> None: + """POST to a vLLM server endpoint with bounded retry on 429 / connection errors.""" + url = f"{self.vllm_server_url}{path}" + for attempt in range(retries): + try: + resp = requests.post(url, json=json_body, timeout=timeout) + if resp.status_code < 429: + resp.raise_for_status() + return + status = resp.status_code + except requests.RequestException as e: + logger.warning(f"[weight_sync] POST {path} failed: {e}") + status = "connection error" + if attempt < retries - 1: + wait = min(2**attempt, 30) + logger.warning(f"[weight_sync] POST {path} -> {status}, retry in {wait}s ({attempt + 1}/{retries})") + time.sleep(wait) + raise RuntimeError(f"[weight_sync] POST {path} failed after {retries} attempt(s)") + + def pause(self) -> None: + requests.post(f"{self.vllm_server_url}/pause", params={"mode": "keep"}) + + def resume(self) -> None: + requests.post(f"{self.vllm_server_url}/resume") + + +class NCCLWeightTransfer(WeightTransfer): + """Broadcast weights to vLLM over a shared NCCL group. Single-phase: pause, broadcast, resume.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_update_group = None + + def init(self, accelerator) -> None: + if not accelerator.is_main_process: + return self._wait_for_server_ready_sync() - response = requests.get(f"{self.vllm_server_url}/get_world_size") - inference_world_size = response.json()["world_size"] + inference_world_size = self._get_world_size() world_size = inference_world_size + 1 master_address = get_ip() master_port = get_open_port() @@ -83,6 +167,7 @@ def init_weight_transfer(self) -> None: "rank_offset": 1, "world_size": world_size, } + # vLLM's /init joins the group on the worker side; the trainer joins concurrently from this rank. 
t_init = threading.Thread( target=requests.post, args=(f"{self.vllm_server_url}/init_weight_transfer_engine",), @@ -90,48 +175,126 @@ def init_weight_transfer(self) -> None: ) t_init.start() self.model_update_group = NCCLWeightTransferEngine.trainer_init( - { - "master_address": master_address, - "master_port": master_port, - "world_size": world_size, - } + {"master_address": master_address, "master_port": master_port, "world_size": world_size} ) t_init.join() logger.info("Initialised weight-transfer NCCL group with vLLM") - def send_weights(self, iterator) -> None: - if self.model_update_group is None: - return + def _get_world_size(self, attempts: int = 30, poll_interval_s: float = 2.0) -> int: + """Read vLLM's world size, retrying until the engine RPC is ready. ``/health`` can go green before + ``/get_world_size`` answers, so poll instead of indexing the first response blindly.""" + for _ in range(attempts): + try: + world_size = requests.get(f"{self.vllm_server_url}/get_world_size", timeout=5).json().get("world_size") + if world_size is not None: + return int(world_size) + except (requests.RequestException, ValueError): + pass + time.sleep(poll_interval_s) + raise RuntimeError(f"vLLM /get_world_size did not return a world_size after {attempts} attempts") + + def sync(self, *, iter_fn: WeightIterFn, sparse: bool, is_anchor: bool, version: int, accelerator) -> None: + is_main = accelerator.is_main_process t0 = time.time() + if is_main: + self.pause() + accelerator.wait_for_everyone() # broadcast must start in lockstep across FSDP ranks + if is_main: + kind = "sparse" if sparse else ("anchor" if is_anchor else "full") + logger.info("Weight sync: NCCL %s broadcast...", kind) + if sparse: + self._send_sparse(iter_fn(True)) + else: + self._send_full(iter_fn(False)) + else: + for _ in iter_fn(sparse): # participate in the full_tensor() collectives + pass + accelerator.wait_for_everyone() + if is_main: + self.resume() + logger.info("Weight sync: done (NCCL, %.1fs)", 
time.time() - t0) + + def _send_full(self, iterator) -> None: + """Dense full-policy broadcast (checkpoint format). The worker enters ``receive_weights`` inside the + ``/update_weights`` call, so the NCCL send must run concurrently with that POST.""" + update_info = { + "update_kind": UpdateKind.DENSE, + "names": self._names, + "dtype_names": self._dtype_names, + "shapes": self._shapes, + "packed": True, + } + self._post_vllm("/start_weight_update", {"is_checkpoint_format": True}) t_update = threading.Thread( target=requests.post, args=(f"{self.vllm_server_url}/update_weights",), - kwargs={"json": {"update_info": self._weight_update_info}, "timeout": 1800}, + kwargs={"json": {"update_info": update_info}, "timeout": 1800}, ) t_update.start() - logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)") - t_nccl = time.time() NCCLWeightTransferEngine.trainer_send_weights( - iterator=iterator, + iterator=((name, tensor) for name, tensor, _mask in iterator), trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), ) - logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s") - t_join = time.time() t_update.join() - logger.debug( - f"[weight_sync] /update_weights join took {time.time() - t_join:.1f}s " - f"(total send_weights: {time.time() - t0:.1f}s)" - ) + self._post_vllm("/finish_weight_update", {}) - def pause(self) -> None: - t0 = time.time() - requests.post(f"{self.vllm_server_url}/pause", params={"mode": "keep"}) - logger.debug(f"[weight_sync] pause HTTP took {time.time() - t0:.1f}s") + def _send_sparse(self, iterator) -> None: + """Sparse delta broadcast (kernel format, applied in place via ``index_copy_`` on vLLM). Changed + ``(int32 flat-index, bf16 value)`` pairs are extracted on the GPU; the per-param counts go to vLLM as + ``num_updates_list`` so it can pre-allocate the receive buffers, then each patch is broadcast in that order. 
- def resume(self) -> None: - t0 = time.time() - requests.post(f"{self.vllm_server_url}/resume") - logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") + Extraction runs in **bounded chunks**: the iterator yields one ``full_tensor()``-gathered param at a time, and + we keep only its sparse payload (~1% of the param) before dropping the dense tensor. So rank 0 never holds the + whole gathered model at once — essential under FSDP2 at large scale, where the full model is tens of GB. + """ + # patches: (name, int32 indices, bf16 values, shape) — only the sparse payload is retained across chunks. + patches: list[tuple] = [] + chunk: list[tuple] = [] + chunk_numel = 0 + for name, tensor, mask in iterator: + chunk.append((name, tensor, mask)) + chunk_numel += tensor.numel() + if chunk_numel >= 256_000_000: + patches.extend(self._extract_chunk(chunk)) + chunk, chunk_numel = [], 0 # drop the chunk's dense gathered tensors + if chunk: + patches.extend(self._extract_chunk(chunk)) + + names = [name for name, _, _, _ in patches] + if not names: # nothing changed this step -> vLLM rejects an empty sparse update; skip + logger.debug("[weight_sync] sparse NCCL: no changed params, skipping transfer") + return + update_info = { + "update_kind": UpdateKind.SPARSE_FLAT, + "names": names, + "dtype_names": ["bfloat16"] * len(names), + "shapes": [shape for _, _, _, shape in patches], + "num_updates_list": [int(idx.numel()) for _, idx, _, _ in patches], + "packed": False, + } + self._post_vllm("/start_weight_update", {"is_checkpoint_format": False}) + t_update = threading.Thread( + target=requests.post, + args=(f"{self.vllm_server_url}/update_weights",), + kwargs={"json": {"update_info": update_info}, "timeout": 1800}, + ) + t_update.start() + NCCLWeightTransferEngine.trainer_send_sparse_weights( + iterator=(SparseWeightPatch(name=name, indices=idx, values=vals) for name, idx, vals, _ in patches), + trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, 
src=0, packed=False), + ) + t_update.join() + self._post_vllm("/finish_weight_update", {}) + + @staticmethod + def _extract_chunk(chunk) -> list[tuple]: + """Extract ``(name, int32 indices, bf16 values, shape)`` for a chunk of ``(name, full, mask)`` triples. + Values are cast to bf16 to match the dtype vLLM allocates for the receive buffer (the served policy is bf16).""" + shape_by_name = {name: list(tensor.shape) for name, tensor, _ in chunk} + return [ + (name, idx.contiguous(), vals.to(torch.bfloat16).contiguous(), shape_by_name[name]) + for name, idx, vals in extract_sparse_batched(chunk) + ] def destroy(self) -> None: if self.model_update_group is None: @@ -139,3 +302,109 @@ def destroy(self) -> None: self.model_update_group.group.store = None self.model_update_group.group.socket = None self.model_update_group = None + + +class BucketWeightTransfer(WeightTransfer): + """Route the patch through an HF Storage Bucket. Two-phase: upload while inference runs, then pause → apply → + resume (vLLM fetches the patch from the bucket inside the apply). + + Args: + bucket_id (`str`): + HF Storage Bucket for the patches/anchors (created if missing). + encoding (`str`, *optional*, defaults to `"gap_delta"`): + Index encoding for the patches: ``"raw"``, ``"gap_delta"``, or ``"nvcomp_cascaded"``. 
+ """ + + def __init__(self, *args, bucket_id: str, encoding: str = "gap_delta", **kwargs): + super().__init__(*args, **kwargs) + self._bucket_id = bucket_id + self._encoding = encoding + self._pending: PendingPatch | None = None + + def init(self, accelerator) -> None: + if not accelerator.is_main_process: + return + self._wait_for_server_ready_sync() + create_bucket(self._bucket_id, exist_ok=True) + requests.post( + f"{self.vllm_server_url}/init_weight_transfer_engine", + json={"init_info": {}}, + timeout=self.init_weight_transfer_timeout, + ) + logger.info("Initialised bucket weight transfer (bucket %s)", self._bucket_id) + + def sync(self, *, iter_fn: WeightIterFn, sparse: bool, is_anchor: bool, version: int, accelerator) -> None: + is_main = accelerator.is_main_process + t0 = time.time() + # Phase 1: encode + upload the patch while inference keeps running (weights materialized here). + if is_main: + logger.info("Weight sync: uploading %s patch to bucket...", "anchor" if is_anchor else "delta") + self._upload(iter_fn(sparse), is_anchor=is_anchor, version=version) + else: + for _ in iter_fn(sparse): # participate in the full_tensor() collectives + pass + accelerator.wait_for_everyone() + # Phase 2: pause, then signal vLLM to fetch + apply the uploaded patch. 
+ if is_main: + self.pause() + try: + self._apply() + except Exception as e: + logger.warning(f"Weight sync: bucket apply failed ({e}), skipping; vLLM keeps stale weights") + self.resume() + logger.info("Weight sync: done (bucket, %.1fs)", time.time() - t0) + accelerator.wait_for_everyone() + + def _upload(self, iterator, is_anchor: bool, version: int) -> None: + if is_anchor: + iterator = ((name, tensor, None) for name, tensor, _mask in iterator) # strip masks -> full tensors + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{version:06d}.safetensors" + meta = HFBucketWeightTransferEngine.upload( + iterator=iterator, + bucket_id=self._bucket_id, + filename=filename, + model_version=version, + encoding=self._encoding, + ) + self._pending = ( + None + if meta is None + else PendingPatch( + repo_id=self._bucket_id, + filename=filename, + update_kind=UpdateKind.DENSE if is_anchor else UpdateKind.SPARSE_FLAT, + ) + ) + + def _apply(self) -> None: + """No-op when nothing was uploaded this step; ``_pending`` is cleared up front so a failed apply leaves no + stale state.""" + pending, self._pending = self._pending, None + if pending is None: + return + # Anchors are HF-checkpoint-format full tensors; deltas are sparse kernel-format. + self._post_vllm("/start_weight_update", {"is_checkpoint_format": pending["update_kind"] is UpdateKind.DENSE}) + # vLLM fetches the patch inside this call; a full anchor can take minutes, so the timeout must cover the + # download, otherwise a read-timeout would retry into a re-download. 
+ self._post_vllm("/update_weights", {"update_info": pending}, retries=5, timeout=1800) + self._post_vllm("/finish_weight_update", {}) + + +def make_weight_transfer( + backend: str, + *, + vllm_server_url: str, + weight_update_info: dict, + server_timeout: float = 240.0, + bucket_id: str | None = None, + encoding: str = "gap_delta", +) -> WeightTransfer: + """Build the [`WeightTransfer`] for ``backend`` (``"nccl"`` or ``"bucket"``).""" + if backend == "nccl": + return NCCLWeightTransfer(vllm_server_url, weight_update_info, server_timeout=server_timeout) + if backend == "bucket": + if bucket_id is None: # fail early instead of passing None to create_bucket() during init + raise ValueError("weight_sync_backend='bucket' requires a bucket_id") + return BucketWeightTransfer( + vllm_server_url, weight_update_info, server_timeout=server_timeout, bucket_id=bucket_id, encoding=encoding + ) + raise ValueError(f"Unknown weight_sync_backend: {backend!r}")
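
Review note on the inversion in `_adamw_reconstruct_pre_step` (`weight_diff.py`, earlier in this diff): the formula can be sanity-checked in isolation. The block below is a minimal scalar sketch in plain Python, not the PR's tensor implementation. `adamw_step` and `reconstruct_pre_step` are hypothetical helper names, and the hyperparameters and gradients are arbitrary values chosen for illustration. It applies a few decoupled AdamW steps and, after each one, recovers the pre-step weight from the post-step weight and the post-step moments only.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One decoupled AdamW step on scalars; returns (theta_t, m_t, v_t)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    bc1 = 1 - beta1**t  # bias correction for the first moment
    bc2 = 1 - beta2**t  # bias correction for the second moment
    denom = math.sqrt(v) / math.sqrt(bc2) + eps
    theta = theta * (1 - lr * wd) - (lr / bc1) * m / denom
    return theta, m, v


def reconstruct_pre_step(theta_t, m_t, v_t, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """Invert the step above using only the post-step weight and moments."""
    bc1 = 1 - beta1**t
    bc2 = 1 - beta2**t
    denom = math.sqrt(v_t) / math.sqrt(bc2) + eps
    return (theta_t + (lr / bc1) * m_t / denom) / (1 - lr * wd)


# Run a few steps and track the worst reconstruction error (illustrative values).
theta, m, v = 0.5, 0.0, 0.0
max_err = 0.0
for t, grad in enumerate([0.1, -0.3, 0.05], start=1):
    theta_prev = theta
    theta, m, v = adamw_step(theta, grad, m, v, t)
    recovered = reconstruct_pre_step(theta, m, v, t)
    max_err = max(max_err, abs(recovered - theta_prev))
```

In double precision the reconstruction matches the pre-step weight up to floating-point error. The PR performs the same comparison after rounding both sides to bf16, which is why its docstring notes that rare elements near a rounding boundary may flip and relies on periodic anchors to bound that drift.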