diff --git a/examples/scripts/openenv/async_wordle.py b/examples/scripts/openenv/async_wordle.py new file mode 100644 index 00000000000..dc7047825d8 --- /dev/null +++ b/examples/scripts/openenv/async_wordle.py @@ -0,0 +1,303 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "trackio", +# "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle", +# ] +# /// + + +""" +Async GRPO training with Wordle environment using delta weight sync to a remote vLLM HF Space. + +Architecture: + Local (1 GPU): AsyncGRPOTrainer + rollout worker (env tool calls run locally) + Remote Space 1: vLLM server with DeltaWorkerExtension (GPU, serves /v1/completions) + Remote Space 2: TextArena Wordle game server (no GPU, already at openenv-wordle.hf.space) + HF Hub Bucket: Stores weight anchors and sparse deltas + +# Option 1: Remote vLLM Space + Remote Wordle Space (fully remote inference) + +## Step 1: Deploy vLLM on HF Spaces + +The Dockerfile and README.md are provided in `examples/scripts/openenv/vllm_space/`. +Deploy with the HF CLI: + +```sh +# Create the Space (l4x1 GPU, Docker SDK) +hf repos create /vllm-wordle-inference \\ + --type space --space-sdk docker --flavor l4x1 \\ + --secrets HF_TOKEN=$HF_TOKEN \\ + --env VLLM_SERVER_DEV_MODE=1 + +# Upload Dockerfile + README +hf upload /vllm-wordle-inference \\ + examples/scripts/openenv/vllm_space/ --type space + +# Check status +hf spaces info /vllm-wordle-inference +``` + +## Step 2: Run training locally (1 GPU) + +```sh +CUDA_VISIBLE_DEVICES=0 python examples/scripts/openenv/async_wordle.py \\ + --vllm-server-url https://.hf.space \\ + --env-url https://openenv-wordle.hf.space \\ + --delta-sync-repo-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` + +# Option 2: Local vLLM + Remote Wordle Space (for testing) + +## Terminal 1: Spin up local vLLM server + +```sh +CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-1.7B \\ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ + --weight-transfer-config '{"backend":"nccl"}' \\ + --max-model-len 8192 \\ + --enforce-eager \\ + --gpu-memory-utilization 0.8 \\ + --logprobs-mode processed_logprobs +``` + +## Terminal 2: Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/async_wordle.py \\ + --vllm-server-url http://localhost:8000 \\ + --delta-sync-repo-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` +""" + +import argparse +import logging +import os + +from datasets import Dataset +from textarena_env import TextArenaAction, TextArenaEnv + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer + + +logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Async GRPO training for Wordle with delta weight sync to a remote vLLM HF Space." + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-1.7B", + help="Model identifier passed to AsyncGRPOTrainer for fine-tuning.", + ) + parser.add_argument( + "--env-url", + type=str, + default="https://openenv-wordle.hf.space", + help="URL for the Wordle environment server.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8000", + help="URL for the vLLM server (local or remote HF Space).", + ) + parser.add_argument( + "--delta-sync-repo-id", + type=str, + default=None, + help="HF Hub bucket for delta weight patches (e.g. 'user/wordle-deltas'). Required.", + ) + parser.add_argument( + "--delta-sync-anchor-interval", + type=int, + default=10, + help="Upload a full anchor checkpoint every N weight syncs.", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=1000, + help="Number of entries in the synthetic training dataset.", + ) + parser.add_argument( + "--num-generations", + type=int, + default=16, + help="Number of rollout generations per prompt.", + ) + parser.add_argument( + "--per-device-train-batch-size", + type=int, + default=32, + help="Per-device training batch size.", + ) + parser.add_argument( + "--max-steps", + type=int, + default=100, + help="Maximum number of training steps.", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-6, + help="Learning rate for training.", + ) + parser.add_argument( + "--max-staleness", + type=int, + default=5, + help="Drop rollout samples generated more than this many weight versions ago.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory where training outputs and checkpoints are stored.", + ) + parser.add_argument( + "--trackio-space-id", + type=str, + default="aminediroHF/async_wordle_trackio", + help="Trackio space identifier for logging.", + ) + return parser.parse_args() + + +prompt = """You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies. + +Follow these rules to play Wordle: + +1. The target is a 5-letter English word +2. You have 6 attempts to guess the correct word +3. After each guess, you receive color-coded feedback: + - GREEN (G): Letter is correct and in the correct position + - YELLOW (Y): Letter is in the word but in the wrong position + - GRAY (X): Letter is not in the word at all +4. All guesses must be valid 5-letter English words +5. You cannot reuse a word you've already guessed +6. Use the tool `guess` to make a guess. +""" + + +def reward_func(environments, **kwargs) -> list[float]: + return [env.reward for env in environments] + + +def main() -> None: + args = parse_args() + + env_url = args.env_url + + class WordleEnv: + def __init__(self): + self.client = TextArenaEnv(base_url=env_url).sync() + self.reward = 0.0 + self.done = False + + def _reconnect(self): + self.client = TextArenaEnv(base_url=env_url).sync() + + def reset(self, **kwargs) -> str | None: + try: + result = self.client.reset() + except Exception: + self._reconnect() + result = self.client.reset() + # The game returns cumulative feedback each turn (new text appended at the end), so + # we store the previous full response and slice out only the newly appended part. + self._last_full_feedback = result.observation.messages[0].content + self.reward = 0.0 + self.done = False + return self._last_full_feedback + + def guess(self, guess: str) -> str: + """ + Make a guess in the Wordle environment. + + Args: + guess: The guessed word, formatted as '[abcde]' + + Returns: + The feedback message from the environment. + """ + if self.done: + raise ValueError("Game over.") + try: + result = self.client.step(TextArenaAction(message=guess)) + except Exception: + self._reconnect() + result = self.client.step(TextArenaAction(message=guess)) + _full_feedback = result.observation.messages[0].content + # Just take the new feedback since the last guess, which is the part appended to the end of the full feedback + feedback = _full_feedback[len(self._last_full_feedback) :] + self._last_full_feedback = _full_feedback + # For some reason, the environment doesn't penalize invalid moves and just returns the last reward. + # We check the feedback for the invalid move message and penalize it if found. + if "You attempted an invalid move" in feedback: + self.reward = 0.0 + else: + self.reward = result.reward + self.done = result.done + return feedback + + output_dir = args.output_dir or f"{args.model.split('/')[-1]}-async-wordle-GRPO" + dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]}) + + config = AsyncGRPOConfig( + delta_sync_enabled=True, + delta_sync_repo_id=args.delta_sync_repo_id, + delta_sync_anchor_interval=args.delta_sync_anchor_interval, + vllm_server_base_url=args.vllm_server_url, + learning_rate=args.learning_rate, + bf16=True, + output_dir=output_dir, + max_completion_length=1024, + max_tool_calling_iterations=3, + per_device_train_batch_size=args.per_device_train_batch_size, + num_generations=args.num_generations, + max_staleness=args.max_staleness, + max_steps=args.max_steps, + logging_steps=1, + log_completions=True, + num_completions_to_print=1, + report_to="trackio", + trackio_space_id=args.trackio_space_id, + chat_template_kwargs={"enable_thinking": False}, + ) + + trainer = AsyncGRPOTrainer( + model=args.model, + args=config, + train_dataset=dataset, + reward_funcs=reward_func, + environment_factory=WordleEnv, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/openenv/vllm_space/Dockerfile b/examples/scripts/openenv/vllm_space/Dockerfile new file mode 100644 index 00000000000..a71babd07a5 --- /dev/null +++ b/examples/scripts/openenv/vllm_space/Dockerfile @@ -0,0 +1,28 @@ +FROM vllm/vllm-openai:latest + +# Install git (needed to pip install from git repos), then TRL from delta-weight-sync branch +RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* +RUN pip install "trl @ git+https://github.com/huggingface/trl.git@delta-weight-sync" +RUN pip install "transformers==5.2.0" + +# HF Spaces expects port 7860 +EXPOSE 7860 + +# HF Spaces runs as uid 1000 without a passwd entry — PyTorch needs USER set +ENV VLLM_SERVER_DEV_MODE=1 +ENV USER=user +ENV HOME=/tmp +ENV HF_HOME=/tmp/hf_cache +ENV TORCH_HOME=/tmp/torch_cache +ENV XDG_CACHE_HOME=/tmp/.cache +ENV FLASHINFER_WORKSPACE_DIR=/tmp/flashinfer + +ENTRYPOINT ["vllm", "serve", "Qwen/Qwen3-1.7B", \ + "--host", "0.0.0.0", \ + "--port", "7860", \ + "--worker-extension-cls", "trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension", \ + "--weight-transfer-config", "{\"backend\":\"nccl\"}", \ + "--max-model-len", "32768", \ + "--enforce-eager", \ + "--gpu-memory-utilization", "0.8", \ + "--logprobs-mode", "processed_logprobs"] diff --git a/examples/scripts/openenv/vllm_space/README.md b/examples/scripts/openenv/vllm_space/README.md new file mode 100644 index 00000000000..ace0b5e0651 --- /dev/null +++ b/examples/scripts/openenv/vllm_space/README.md @@ -0,0 +1,14 @@ +--- +title: vLLM Wordle Inference +emoji: 🎮 +colorFrom: blue +colorTo: green +sdk: docker +app_port: 7860 +hardware: l4 +--- + +vLLM server with DeltaWorkerExtension for async GRPO training. + +Serves Qwen/Qwen3-1.7B with delta weight sync via HF Hub buckets. +Used by `examples/scripts/openenv/async_wordle.py` in the TRL repo. diff --git a/examples/scripts/openenv/wordle_space/Dockerfile b/examples/scripts/openenv/wordle_space/Dockerfile new file mode 100644 index 00000000000..bfeca9af1e7 --- /dev/null +++ b/examples/scripts/openenv/wordle_space/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.11-slim + +RUN apt-get update && apt-get install -y --no-install-recommends git git-lfs && rm -rf /var/lib/apt/lists/* +RUN git lfs install +ENV GIT_CLONE_PROTECTION_ACTIVE=false +RUN pip install --no-build-isolation "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle" \ + || (git clone https://huggingface.co/spaces/openenv/wordle /tmp/wordle && pip install /tmp/wordle && rm -rf /tmp/wordle) + +COPY app.py /app/app.py + +EXPOSE 7860 + +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"] + +WORKDIR /app diff --git a/examples/scripts/openenv/wordle_space/README.md b/examples/scripts/openenv/wordle_space/README.md new file mode 100644 index 00000000000..6857374dff2 --- /dev/null +++ b/examples/scripts/openenv/wordle_space/README.md @@ -0,0 +1,11 @@ +--- +title: Wordle Environment (High Capacity) +emoji: 🟩 +colorFrom: green +colorTo: yellow +sdk: docker +app_port: 7860 +--- + +Wordle environment server with 256 concurrent session support for async GRPO training. +Used by `examples/scripts/openenv/async_wordle.py` in the TRL repo. diff --git a/examples/scripts/openenv/wordle_space/app.py b/examples/scripts/openenv/wordle_space/app.py new file mode 100644 index 00000000000..9e0e251df8d --- /dev/null +++ b/examples/scripts/openenv/wordle_space/app.py @@ -0,0 +1,20 @@ +"""Wordle environment server with higher concurrent session capacity for async GRPO training.""" + +import os + +from openenv.core.env_server.http_server import ConcurrencyConfig, create_app +from textarena_env.models import TextArenaAction, TextArenaObservation +from textarena_env.server.environment import TextArenaEnvironment + + +# Mark TextArena as safe for concurrent sessions (each session gets its own game instance) +TextArenaEnvironment.SUPPORTS_CONCURRENT_SESSIONS = True + +max_sessions = int(os.getenv("MAX_CONCURRENT_SESSIONS", "256")) + +app = create_app( + TextArenaEnvironment, + TextArenaAction, + TextArenaObservation, + concurrency_config=ConcurrencyConfig(max_concurrent_envs=max_sessions), +) diff --git a/pyproject.toml b/pyproject.toml index f4831765478..15b0add31dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.4.0", "datasets>=4.7.0", # Support Json type and on_mixed_types="use_json" + "huggingface-hub>=0.36.2", "packaging>20.0", "transformers>=4.56.2", ] diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..775c6811824 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -185,6 +185,38 @@ class AsyncGRPOConfig(_BaseConfig): metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + # Delta weight sync + delta_sync_enabled: bool = field( + default=False, + metadata={ + "help": "Enable delta-compressed weight synchronization. Instead of transferring all " + "weights over NCCL, encode only changed bf16 weights as sparse safetensors patches, " + "upload them to HuggingFace Hub (Xet storage), and signal vLLM to fetch and apply them." + }, + ) + delta_sync_repo_id: str | None = field( + default=None, + metadata={ + "help": "HuggingFace Hub repository for storing delta weight patches and anchor " + "checkpoints (e.g. 'user/training-run-xyz'). Required when delta_sync_enabled=True. " + "The repo is created automatically if it does not exist." + }, + ) + delta_sync_anchor_interval: int = field( + default=10, + metadata={ + "help": "Save a full bf16 anchor checkpoint every N weight sync steps. Between anchors " + "only sparse delta patches are saved. Fireworks blog uses N=25, PULSE paper uses N=50." + }, + ) + delta_sync_verify_checksum: bool = field( + default=True, + metadata={ + "help": "Verify SHA256 checksum after applying each delta patch on the vLLM server. " + "Adds overhead per sync but guarantees bit-exact reconstruction." + }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -201,6 +233,9 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + if self.delta_sync_enabled and self.delta_sync_repo_id is None: + raise ValueError("delta_sync_repo_id is required when delta_sync_enabled=True") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index aca72c73596..f52b0a5bef3 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,6 +35,7 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker +from .weight_diff import BF16ChangeDetector logger = get_logger(__name__) @@ -379,6 +380,9 @@ def __init__( weight_names=weight_names, weight_dtype_names=weight_dtype_names, weight_shapes=weight_shapes, + delta_sync_enabled=self.args.delta_sync_enabled, + delta_sync_repo_id=self.args.delta_sync_repo_id, + delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -388,6 +392,9 @@ def __init__( # Add callbacks self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps)) + # ULP change detector for diagnostic logging (delta sync only) + self._change_detector: BF16ChangeDetector | None = None + def get_train_dataloader(self) -> DataLoader: if self.accelerator.is_main_process: dataset = RolloutQueueDataset( @@ -562,43 +569,118 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() def _streaming_iter(self): - # Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across - # FSDP ranks, then frees it once the generator advances — avoiding materializing the full model in memory. + """Yield ``(name, tensor, mask)`` tuples. + + - No change detector (NCCL path or first sync): all params, ``mask=None``. + - Change detector active: only changed params with element-level masks. + + The anchor/delta decision is NOT made here — the rollout worker handles that. + """ + if self._change_detector is None or not self._change_detector._validated_masks: + for name, param in self.model.named_parameters(): + name = name.removeprefix("module.") + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None + return + + masks = self._change_detector._validated_masks + total, yielded = 0, 0 for name, param in self.model.named_parameters(): - name = name.removeprefix("module.") # DDP/FSDP1 wrapping - full = param.full_tensor() if isinstance(param, DTensor) else param.detach() - yield name, full + total += 1 + name = name.removeprefix("module.") + mask = masks.get(name) + if mask is None or not mask.any(): + continue + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask + yielded += 1 + logger.info(f"Delta: {yielded}/{total} params changed") def _sync_weight(self): + # Lazy-init ULP detector for diagnostic logging (delta sync only) bc + # Optimizer only exists after Trainer creates it inside super()._inner_training_loop(). + if ( + self.args.delta_sync_enabled + and self._change_detector is None + and hasattr(self, "optimizer") + and self.optimizer is not None + ): + # TODO(@aminediro): check this works with FSDP2 + # Unwrap AcceleratedOptimizer to get the native PyTorch optimizer + # (register_step_pre_hook requires torch.optim.Optimizer internals) + raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) + self._change_detector = BF16ChangeDetector(self.model, raw_optimizer) + + if ( + self.args.delta_sync_enabled + and self._change_detector is not None + and self._change_detector._validated_masks + and self.accelerator.is_main_process + ): + total_changed = 0 + total_elements = 0 + for mask in self._change_detector._validated_masks.values(): + total_changed += mask.sum().item() + total_elements += mask.numel() + sparsity = 1.0 - total_changed / max(total_elements, 1) + self._metrics["train"]["delta/sparsity"].append(sparsity) + self._metrics["train"]["delta/total_changed"].append(total_changed) + self._metrics["train"]["delta/total_elements"].append(total_elements) + logger.info(f"Delta: {total_changed}/{total_elements} elements changed (sparsity={sparsity:.4%})") + t0 = time.time() - logger.info("Weight sync: pausing vLLM...") + is_delta = self.args.delta_sync_enabled + + if is_delta: + # Phase 1: Upload to HF Hub while inference continues + logger.info("Weight sync: uploading to HF Hub (inference still running)...") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + self.accelerator.wait_for_everyone() + t_upload = time.time() + logger.info(f"Weight sync: upload took {t_upload - t0:.1f}s, now pausing vLLM...") + + # Phase 2: Pause inference if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.pause() t_pause = time.time() - logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...") self.accelerator.wait_for_everyone() t_barrier = time.time() - logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") - if self.accelerator.is_main_process and self.rollout_worker: - self.rollout_worker.send_weights(self._streaming_iter()) + if is_delta: + # Phase 3: Signal vLLM to fetch the already-uploaded weights + logger.info(f"Weight sync: signaling vLLM to apply... (pause took {t_pause - t_upload:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + try: + self.rollout_worker.send_weights(iter([])) + except Exception as e: + logger.warning(f"Weight sync: apply failed ({e}), skipping — vLLM will use stale weights") else: - # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. - for _ in self._streaming_iter(): - pass + # NCCL: transfer all weights directly + logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + t_transfer = time.time() self.accelerator.wait_for_everyone() - logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)") + # Phase 4: Resume + logger.info(f"Weight sync: resuming vLLM... (apply took {t_transfer - t_barrier:.1f}s)") if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.resume() self.model_version += 1 self.rollout_worker.update_model_version(self.model_version) weight_sync_time_s = time.time() - t0 self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") + logger.info( + f"Weight sync: done. Total {weight_sync_time_s:.1f}s (inference paused {t_transfer - t_pause:.1f}s)" + ) def _inner_training_loop(self, *args, **kwargs): # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train() @@ -613,3 +695,6 @@ def _inner_training_loop(self, *args, **kwargs): finally: if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.stop() + if self._change_detector is not None: + self._change_detector.close() + self._change_detector = None diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..1671a85dda0 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -13,7 +13,9 @@ # limitations under the License. import asyncio +import copy import inspect +import itertools import queue import threading import time @@ -26,12 +28,15 @@ import requests from accelerate.logging import get_logger from datasets import Dataset +from huggingface_hub import create_bucket from transformers import AutoTokenizer from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response from trl.import_utils import is_vllm_available from trl.trainer.utils import print_prompt_completions_sample +from .delta_engine import DeltaWeightTransferEngine + if is_vllm_available(min_version="0.17.1"): from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine @@ -56,6 +61,7 @@ class RolloutGroup: tool_mask: list[list[int]] tool_call_counts: list[int] tool_failure_counts: list[int] + environments: list[object] model_version: int queued_at: float = 0.0 @@ -102,11 +108,15 @@ def __init__( weight_names: list[str] | None = None, weight_dtype_names: list[str] | None = None, weight_shapes: list[list[int]] | None = None, + delta_sync_enabled: bool = False, + delta_sync_repo_id: str | None = None, + delta_sync_anchor_interval: int = 10, ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'" ) + self.delta_sync_enabled = delta_sync_enabled self.model_name = model_name self.max_tool_calling_iterations = max_tool_calling_iterations self.dataset = dataset @@ -171,7 +181,10 @@ def __init__( self.model_version = 0 self.session = None - # Wait for the vLLM server and initialize NCCL weight transfer. + self._delta_sync_repo_id = delta_sync_repo_id + self._delta_sync_anchor_interval = delta_sync_anchor_interval + + # Wait for the vLLM server and initialize weight transfer. self._wait_for_server_ready_sync(timeout_s=self.server_timeout) self._init_weight_transfer() @@ -199,6 +212,18 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: time.sleep(poll_interval_s) def _init_weight_transfer(self) -> None: + if self.delta_sync_enabled: + create_bucket(self._delta_sync_repo_id, exist_ok=True) + self._delta_model_version = 0 + self._delta_pending_update_info: dict | None = None + requests.post( + f"{self.vllm_server_url}/init_weight_transfer_engine", + json={"init_info": {}}, + timeout=120, + ) + logger.info("Init delta weight transfer with HF Hub repo %s", self._delta_sync_repo_id) + return + response = requests.get(f"{self.vllm_server_url}/get_world_size") inference_world_size = response.json()["world_size"] world_size = inference_world_size + 1 @@ -287,6 +312,73 @@ def resume(self) -> None: logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") def send_weights(self, iterator) -> None: + if self.delta_sync_enabled: + self._send_weights_delta(iterator) + else: + self._send_weights_nccl(iterator) + + def _send_weights_delta(self, iterator) -> None: + """Delta sync via HF Bucket. + + - Non-empty iterator: upload (anchor or delta based on step count). + - Empty iterator + pending info: signal vLLM to apply. + - Empty iterator + nothing pending: no-op. + """ + first = next(iterator, None) + + # (empty iterator) + if first is None: + if self._delta_pending_update_info is not None: + for attempt in range(5): + try: + resp = requests.post( + f"{self.vllm_server_url}/update_weights", + json={"update_info": self._delta_pending_update_info}, + timeout=300, + ) + if resp.status_code < 429: + break + except requests.RequestException as e: + resp = None + logger.warning(f"[weight_sync] /update_weights request failed: {e}") + wait = min(2**attempt, 30) + status = resp.status_code if resp is not None else "connection error" + logger.warning( + f"[weight_sync] /update_weights returned {status}, " + f"retrying in {wait}s (attempt {attempt + 1}/5)" + ) + time.sleep(wait) + if resp is not None: + resp.raise_for_status() + self._delta_pending_update_info = None + return + + # Upload phase + self._delta_model_version += 1 + is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0 + + full_iter = itertools.chain([first], iterator) + if is_anchor: + # Force full tensors — strip masks + full_iter = ((name, tensor, None) for name, tensor, _mask in full_iter) + + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" + meta = DeltaWeightTransferEngine.upload( + iterator=full_iter, + bucket_id=self._delta_sync_repo_id, + filename=filename, + model_version=self._delta_model_version, + ) + if meta is not None: + self._delta_pending_update_info = { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "is_checkpoint_format": True, + } + + def _send_weights_nccl(self, iterator) -> None: + """NCCL sync: broadcast all params via NCCL + signal /update_weights.""" if self.model_update_group is None: return t0 = time.time() @@ -299,7 +391,7 @@ def send_weights(self, iterator) -> None: logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)") t_nccl = time.time() NCCLWeightTransferEngine.trainer_send_weights( - iterator=iterator, + iterator=((name, tensor) for name, tensor, _mask in iterator), trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), ) logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s") @@ -346,6 +438,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: tool_mask=[], tool_call_counts=[], tool_failure_counts=[], + environments=[], model_version=self.model_version, ) pending_completed[group_id] = 0 @@ -353,10 +446,14 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: slot = free_slots.pop() if self.environments is not None: - # Current assumption: reset side effects matter, return value is ignored. - self.environments[slot].reset(**row) - - logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") + try: + self.environments[slot].reset(**row) + except Exception as e: + logger.warning(f"[slot={slot}] env.reset() failed: {e}, skipping generation") + free_slots.add(slot) + continue + + logger.debug(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") task = asyncio.create_task( self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot]) ) @@ -382,7 +479,27 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: free_slots.add(slot) logger.debug(f"[slot] freed slot={slot} group={group_id} free_after={len(free_slots)}") if task.exception() is not None: - raise task.exception() + logger.warning( + f"[slot={slot}] generation failed for group {group_id}: {task.exception()}, skipping" + ) + pending_completed[group_id] += 1 + if pending_completed[group_id] == self.num_generations: + # All generations attempted but some failed — drop the group + group = pending_groups[group_id] + if group.completions: + # Score whatever we have + group.queued_at = time.monotonic() + try: + self._groups_to_score.put_nowait(group) + except asyncio.QueueFull: + pass + logger.warning( + f"Group {group_id} had failures, got {len(pending_groups[group_id].completions)}" + f"/{self.num_generations} completions" + ) + del pending_groups[group_id] + del pending_completed[group_id] + continue ( completion, @@ -399,6 +516,8 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: group.tool_mask.append(tool_mask) group.tool_call_counts.append(tool_call_count) group.tool_failure_counts.append(tool_failure_count) + if self.environments is not None: + group.environments.append(copy.copy(self.environments[slot])) # TODO: move this in generation task, shouldn't matter but is correct self._total_completion_tokens += sum(tool_mask) pending_completed[group_id] += 1 @@ -636,6 +755,8 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]: completion_ids=group.completions_ids, **group.reward_kwargs, ) + if group.environments: + kwargs["environments"] = group.environments all_rewards = await asyncio.gather( *[ reward_func(**kwargs) diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py new file mode 100644 index 00000000000..110ddf8d199 --- /dev/null +++ b/trl/experimental/async_grpo/delta_engine.py @@ -0,0 +1,223 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import logging +import tempfile +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import Any + +import torch +from huggingface_hub import batch_bucket_files, download_bucket_files +from safetensors import safe_open +from safetensors.torch import save +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +from .weight_diff import PatchMetadata + + +logger = logging.getLogger(__name__) + + +@dataclass +class DeltaWeightTransferInitInfo(WeightTransferInitInfo): + pass + + +@dataclass +class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Metadata sent via ``/update_weights`` — just bucket coordinates.""" + + repo_id: str = "" # bucket_id + filename: str = "" + + +class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): + """Weight transfer engine that uses HF Hub (Xet) as the data plane. + + Worker side: downloads safetensors from Hub, feeds to ``load_weights``. + Trainer side: uploads changed params as safetensors to Hub. + """ + + init_info_cls = DeltaWeightTransferInitInfo + update_info_cls = DeltaWeightTransferUpdateInfo + + def __init__(self, config: WeightTransferConfig, parallel_config: ParallelConfig) -> None: + super().__init__(config, parallel_config) + # TODO: might be able to eliminate completely + # CPU-side bf16 snapshot — needed because vLLM's load_weights expects full + # tensors, so we must reconstruct them from sparse (indices, values) patches. + # Kept on CPU to avoid GPU memory overhead (~2 bytes/param, e.g. ~1.2 GB for 0.6B model). + self._bf16_snapshot: dict[str, torch.Tensor] | None = None + + def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None: + pass + + def receive_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Download safetensors from Hub and feed to load_weights. + + Handles two formats based on the ``sparse`` metadata flag: + + - **Full** (first sync): keys are param names → feed directly to load_weights, + build snapshot for future sparse applies. + - **Sparse** (subsequent): keys are ``{name}.indices`` + ``{name}.values`` → + apply to snapshot, feed reconstructed full tensors to load_weights. + """ + with tempfile.TemporaryDirectory() as tmpdir: + local_path = f"{tmpdir}/weights.safetensors" + download_bucket_files( + update_info.repo_id, + files=[(update_info.filename, local_path)], + ) + + with safe_open(local_path, framework="pt", device="cpu") as f: + meta = PatchMetadata.from_metadata_dict(f.metadata()) + + if not meta.sparse: + self._bf16_snapshot = {} + for name in f.keys(): + tensor = f.get_tensor(name) + self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() + load_weights([(name, tensor)]) + logger.info("Applied anchor (step %d, %d params)", meta.model_version, meta.num_changed_params) + else: + changed_names = json.loads(meta.changed_params) + for name in changed_names: + if name not in self._bf16_snapshot: + logger.warning("Skipping delta for %s: not in snapshot (missing from anchor)", name) + continue + indices = f.get_tensor(f"{name}.indices").long() + values = f.get_tensor(f"{name}.values") + snap = self._bf16_snapshot[name].flatten() + snap[indices] = values + self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape) + load_weights([(name, self._bf16_snapshot[name])]) + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f)", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + ) + + def shutdown(self) -> None: + self._bf16_snapshot = None + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | Any, + ) -> None: + """Not used directly — the rollout worker manages upload + signaling.""" + raise NotImplementedError("Use AsyncRolloutWorker._send_weights_delta instead") + + @staticmethod + def upload( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + bucket_id: str, + filename: str, + model_version: int = 0, + ) -> PatchMetadata | None: + """Encode params as safetensors and upload to HF Hub. + + Each item is ``(name, tensor, mask)``: + + - ``mask is None``: full tensor stored as ``name`` (anchor). + - ``mask`` provided: sparse encoding — only changed elements stored + as ``{name}.indices`` (int32) + ``{name}.values`` (bf16). + + Returns :class:`PatchMetadata` or ``None`` if the iterator was empty. + """ + tensors: dict[str, torch.Tensor] = {} + changed_names: list[str] = [] + total_changed = 0 + total_elements = 0 + sparse = False + + for name, tensor, mask in iterator: + bf16 = tensor.to(torch.bfloat16).cpu() + total_elements += bf16.numel() + if mask is None: + tensors[name] = bf16.clone() + changed_names.append(name) + total_changed += bf16.numel() + else: + sparse = True + indices = mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) + values = bf16.flatten()[indices.long()] + tensors[f"{name}.indices"] = indices + tensors[f"{name}.values"] = values + changed_names.append(name) + total_changed += len(indices) + + if not tensors: + return None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=len(changed_names), + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + changed_params=json.dumps(changed_names), + ) + buf = save(tensors, metadata=meta.to_metadata_dict()) + + batch_bucket_files(bucket_id, add=[(buf, filename)]) + + logger.info( + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, sparsity=%.4f)", + bucket_id, + filename, + len(buf) / 1e6, + len(changed_names), + sparse, + meta.sparsity, + ) + return meta + + +class DeltaWorkerExtension: + """vLLM worker extension for the delta weight transfer backend. + + This class is intentionally minimal. Its import (via ``--worker-extension-cls``) + registers the engine and overrides the ``"nccl"`` factory entry. + + ``backend`` must be ``"nccl"`` in the CLI (pydantic ``Literal`` validation). + This module overrides the ``"nccl"`` factory entry so that the actual engine + created is ``DeltaWeightTransferEngine``. + """ + + pass + + +if "delta" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) + +# Override "nccl" so --weight-transfer-config '{"backend":"nccl"}' creates our engine. +WeightTransferEngineFactory._registry["nccl"] = lambda: DeltaWeightTransferEngine diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..c53f578aea2 --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization utilities. + +- ``BF16ChangeDetector``: hooks into the optimizer to detect which bf16 elements + actually changed after each step. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both + anchor (full) and delta (sparse) weight files. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass + +import torch + + +logger = logging.getLogger(__name__) + + +class BF16ChangeDetector: + """Detects which bf16 weights actually changed across an optimizer step. + + Hooks into the optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` + (PyTorch >= 2.1). Snapshots bf16 values before the step, compares after. + + ``_validated_masks[name]`` is a boolean tensor with True for each element that changed. + """ + + def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "BF16ChangeDetector: matched %d/%d optimizer params", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + ) + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._pre_step_bf16.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None or name not in self._pre_step_bf16: + continue + self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name] + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +@dataclass +class PatchMetadata: + format: str = "sparse_weight_patch" + version: str = "1" + sparse: bool = False + model_version: int = 0 + num_changed_params: int = 0 + total_changed_elements: int = 0 + total_elements: int = 0 + sparsity: float = 0.0 + changed_params: str = "[]" + + def to_metadata_dict(self) -> dict[str, str]: + return {k: str(v) for k, v in asdict(self).items()} + + @classmethod + def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + for k, v in d.items(): + if k not in field_types: + continue + ft = field_types[k] + if ft == "int": + kwargs[k] = int(v) + elif ft == "float": + kwargs[k] = float(v) + elif ft == "bool": + kwargs[k] = v == "True" + else: + kwargs[k] = v + return cls(**kwargs)