diff --git a/examples/scripts/async_grpo_delta.py b/examples/scripts/async_grpo_delta.py new file mode 100644 index 00000000000..cddf22eb5d3 --- /dev/null +++ b/examples/scripts/async_grpo_delta.py @@ -0,0 +1,95 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AsyncGRPO with delta weight sync (Transport B: HF Storage Bucket + in-place sparse apply). + +Only changed bf16 weights are encoded as a sparse safetensors patch, uploaded to a bucket, and +applied in place on vLLM via PR #40096 โ€” no full-model broadcast, no vLLM-side snapshot. + +Start the vLLM server with the `delta` backend + worker extension (registers the engine) and the +`transformers` model impl (so vLLM's runtime param names match the trainer's HF names โ€” every +param is then addressable by the in-place sparse apply, no fuse/unfuse remap needed): + +# VLLM_USE_V2_MODEL_RUNNER=0 is required: the in-place sparse apply (apply_sparse_weight_patches, +# vLLM #40096) exists only on the V1 model runner. Without it the server picks V2 and every sparse +# delta update fails (the dense anchors still work, so it silently degrades to anchor-only sync). +CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \ + --model-impl transformers \ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \ + --weight-transfer-config '{"backend":"delta"}' \ + --max-model-len 2560 + +CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo_delta.py +""" + +import logging +import os + +from datasets import load_dataset + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer +from trl.rewards import accuracy_reward + + +logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logging.getLogger("trl").setLevel(logging.INFO) + + +def format_sample(sample): + return { + "prompt": [{"role": "user", "content": sample["question"]}], + "solution": sample["answer"].split("####")[-1].strip(), + } + + +def main() -> None: + dataset = load_dataset("openai/gsm8k", "main", split="train") + dataset = dataset.map(format_sample, remove_columns=dataset.column_names) + + config = AsyncGRPOConfig( + output_dir="./results/async_grpo_delta", + per_device_train_batch_size=1, + num_generations=8, + max_completion_length=512, + max_steps=60, + learning_rate=1e-5, + logging_steps=1, + bf16=True, + report_to="none", + project="async_grpo_delta", + log_completions=True, + # Qwen3 thinking traces blow past the completion cap on GSM8K (truncated -> no answer -> + # zero reward); disable thinking so completions are short and accuracy_reward gets signal. + chat_template_kwargs={"enable_thinking": False}, + # --- delta weight sync (Transport B) --- + delta_sync_enabled=True, + delta_sync_repo_id="aminediroHF/async-grpo-delta-demo", + delta_sync_anchor_interval=20, # full anchor every N syncs; sparse deltas in between + delta_sync_encoding="gap_delta", # raw | gap_delta | nvcomp_cascaded + ) + trainer = AsyncGRPOTrainer( + model="Qwen/Qwen3-1.7B", + args=config, + train_dataset=dataset, + reward_funcs=accuracy_reward, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/openenv/async_wordle.py b/examples/scripts/openenv/async_wordle.py new file mode 100644 index 00000000000..dc7047825d8 --- /dev/null +++ b/examples/scripts/openenv/async_wordle.py @@ -0,0 +1,303 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "trackio", +# "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle", +# ] +# /// + + +""" +Async GRPO training with Wordle environment using delta weight sync to a remote vLLM HF Space. + +Architecture: + Local (1 GPU): AsyncGRPOTrainer + rollout worker (env tool calls run locally) + Remote Space 1: vLLM server with DeltaWorkerExtension (GPU, serves /v1/completions) + Remote Space 2: TextArena Wordle game server (no GPU, already at openenv-wordle.hf.space) + HF Hub Bucket: Stores weight anchors and sparse deltas + +# Option 1: Remote vLLM Space + Remote Wordle Space (fully remote inference) + +## Step 1: Deploy vLLM on HF Spaces + +The Dockerfile and README.md are provided in `examples/scripts/openenv/vllm_space/`. +Deploy with the HF CLI: + +```sh +# Create the Space (l4x1 GPU, Docker SDK) +hf repos create /vllm-wordle-inference \\ + --type space --space-sdk docker --flavor l4x1 \\ + --secrets HF_TOKEN=$HF_TOKEN \\ + --env VLLM_SERVER_DEV_MODE=1 + +# Upload Dockerfile + README +hf upload /vllm-wordle-inference \\ + examples/scripts/openenv/vllm_space/ --type space + +# Check status +hf spaces info /vllm-wordle-inference +``` + +## Step 2: Run training locally (1 GPU) + +```sh +CUDA_VISIBLE_DEVICES=0 python examples/scripts/openenv/async_wordle.py \\ + --vllm-server-url https://.hf.space \\ + --env-url https://openenv-wordle.hf.space \\ + --delta-sync-repo-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` + +# Option 2: Local vLLM + Remote Wordle Space (for testing) + +## Terminal 1: Spin up local vLLM server + +```sh +CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-1.7B \\ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ + --weight-transfer-config '{"backend":"nccl"}' \\ + --max-model-len 8192 \\ + --enforce-eager \\ + --gpu-memory-utilization 0.8 \\ + --logprobs-mode processed_logprobs +``` + +## Terminal 2: Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/async_wordle.py \\ + --vllm-server-url http://localhost:8000 \\ + --delta-sync-repo-id /wordle-deltas \\ + --model Qwen/Qwen3-1.7B +``` +""" + +import argparse +import logging +import os + +from datasets import Dataset +from textarena_env import TextArenaAction, TextArenaEnv + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer + + +logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Async GRPO training for Wordle with delta weight sync to a remote vLLM HF Space." + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-1.7B", + help="Model identifier passed to AsyncGRPOTrainer for fine-tuning.", + ) + parser.add_argument( + "--env-url", + type=str, + default="https://openenv-wordle.hf.space", + help="URL for the Wordle environment server.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8000", + help="URL for the vLLM server (local or remote HF Space).", + ) + parser.add_argument( + "--delta-sync-repo-id", + type=str, + default=None, + help="HF Hub bucket for delta weight patches (e.g. 'user/wordle-deltas'). Required.", + ) + parser.add_argument( + "--delta-sync-anchor-interval", + type=int, + default=10, + help="Upload a full anchor checkpoint every N weight syncs.", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=1000, + help="Number of entries in the synthetic training dataset.", + ) + parser.add_argument( + "--num-generations", + type=int, + default=16, + help="Number of rollout generations per prompt.", + ) + parser.add_argument( + "--per-device-train-batch-size", + type=int, + default=32, + help="Per-device training batch size.", + ) + parser.add_argument( + "--max-steps", + type=int, + default=100, + help="Maximum number of training steps.", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-6, + help="Learning rate for training.", + ) + parser.add_argument( + "--max-staleness", + type=int, + default=5, + help="Drop rollout samples generated more than this many weight versions ago.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory where training outputs and checkpoints are stored.", + ) + parser.add_argument( + "--trackio-space-id", + type=str, + default="aminediroHF/async_wordle_trackio", + help="Trackio space identifier for logging.", + ) + return parser.parse_args() + + +prompt = """You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies. + +Follow these rules to play Wordle: + +1. The target is a 5-letter English word +2. You have 6 attempts to guess the correct word +3. After each guess, you receive color-coded feedback: + - GREEN (G): Letter is correct and in the correct position + - YELLOW (Y): Letter is in the word but in the wrong position + - GRAY (X): Letter is not in the word at all +4. All guesses must be valid 5-letter English words +5. You cannot reuse a word you've already guessed +6. Use the tool `guess` to make a guess. +""" + + +def reward_func(environments, **kwargs) -> list[float]: + return [env.reward for env in environments] + + +def main() -> None: + args = parse_args() + + env_url = args.env_url + + class WordleEnv: + def __init__(self): + self.client = TextArenaEnv(base_url=env_url).sync() + self.reward = 0.0 + self.done = False + + def _reconnect(self): + self.client = TextArenaEnv(base_url=env_url).sync() + + def reset(self, **kwargs) -> str | None: + try: + result = self.client.reset() + except Exception: + self._reconnect() + result = self.client.reset() + # The game returns cumulative feedback each turn (new text appended at the end), so + # we store the previous full response and slice out only the newly appended part. + self._last_full_feedback = result.observation.messages[0].content + self.reward = 0.0 + self.done = False + return self._last_full_feedback + + def guess(self, guess: str) -> str: + """ + Make a guess in the Wordle environment. + + Args: + guess: The guessed word, formatted as '[abcde]' + + Returns: + The feedback message from the environment. + """ + if self.done: + raise ValueError("Game over.") + try: + result = self.client.step(TextArenaAction(message=guess)) + except Exception: + self._reconnect() + result = self.client.step(TextArenaAction(message=guess)) + _full_feedback = result.observation.messages[0].content + # Just take the new feedback since the last guess, which is the part appended to the end of the full feedback + feedback = _full_feedback[len(self._last_full_feedback) :] + self._last_full_feedback = _full_feedback + # For some reason, the environment doesn't penalize invalid moves and just returns the last reward. + # We check the feedback for the invalid move message and penalize it if found. + if "You attempted an invalid move" in feedback: + self.reward = 0.0 + else: + self.reward = result.reward + self.done = result.done + return feedback + + output_dir = args.output_dir or f"{args.model.split('/')[-1]}-async-wordle-GRPO" + dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": prompt}] for _ in range(args.dataset_size)]}) + + config = AsyncGRPOConfig( + delta_sync_enabled=True, + delta_sync_repo_id=args.delta_sync_repo_id, + delta_sync_anchor_interval=args.delta_sync_anchor_interval, + vllm_server_base_url=args.vllm_server_url, + learning_rate=args.learning_rate, + bf16=True, + output_dir=output_dir, + max_completion_length=1024, + max_tool_calling_iterations=3, + per_device_train_batch_size=args.per_device_train_batch_size, + num_generations=args.num_generations, + max_staleness=args.max_staleness, + max_steps=args.max_steps, + logging_steps=1, + log_completions=True, + num_completions_to_print=1, + report_to="trackio", + trackio_space_id=args.trackio_space_id, + chat_template_kwargs={"enable_thinking": False}, + ) + + trainer = AsyncGRPOTrainer( + model=args.model, + args=config, + train_dataset=dataset, + reward_funcs=reward_func, + environment_factory=WordleEnv, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/openenv/vllm_space/Dockerfile b/examples/scripts/openenv/vllm_space/Dockerfile new file mode 100644 index 00000000000..a71babd07a5 --- /dev/null +++ b/examples/scripts/openenv/vllm_space/Dockerfile @@ -0,0 +1,28 @@ +FROM vllm/vllm-openai:latest + +# Install git (needed to pip install from git repos), then TRL from delta-weight-sync branch +RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* +RUN pip install "trl @ git+https://github.com/huggingface/trl.git@delta-weight-sync" +RUN pip install "transformers==5.2.0" + +# HF Spaces expects port 7860 +EXPOSE 7860 + +# HF Spaces runs as uid 1000 without a passwd entry โ€” PyTorch needs USER set +ENV VLLM_SERVER_DEV_MODE=1 +ENV USER=user +ENV HOME=/tmp +ENV HF_HOME=/tmp/hf_cache +ENV TORCH_HOME=/tmp/torch_cache +ENV XDG_CACHE_HOME=/tmp/.cache +ENV FLASHINFER_WORKSPACE_DIR=/tmp/flashinfer + +ENTRYPOINT ["vllm", "serve", "Qwen/Qwen3-1.7B", \ + "--host", "0.0.0.0", \ + "--port", "7860", \ + "--worker-extension-cls", "trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension", \ + "--weight-transfer-config", "{\"backend\":\"nccl\"}", \ + "--max-model-len", "32768", \ + "--enforce-eager", \ + "--gpu-memory-utilization", "0.8", \ + "--logprobs-mode", "processed_logprobs"] diff --git a/examples/scripts/openenv/vllm_space/README.md b/examples/scripts/openenv/vllm_space/README.md new file mode 100644 index 00000000000..ace0b5e0651 --- /dev/null +++ b/examples/scripts/openenv/vllm_space/README.md @@ -0,0 +1,14 @@ +--- +title: vLLM Wordle Inference +emoji: ๐ŸŽฎ +colorFrom: blue +colorTo: green +sdk: docker +app_port: 7860 +hardware: l4 +--- + +vLLM server with DeltaWorkerExtension for async GRPO training. + +Serves Qwen/Qwen3-1.7B with delta weight sync via HF Hub buckets. +Used by `examples/scripts/openenv/async_wordle.py` in the TRL repo. diff --git a/examples/scripts/openenv/wordle_space/Dockerfile b/examples/scripts/openenv/wordle_space/Dockerfile new file mode 100644 index 00000000000..bfeca9af1e7 --- /dev/null +++ b/examples/scripts/openenv/wordle_space/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.11-slim + +RUN apt-get update && apt-get install -y --no-install-recommends git git-lfs && rm -rf /var/lib/apt/lists/* +RUN git lfs install +ENV GIT_CLONE_PROTECTION_ACTIVE=false +RUN pip install --no-build-isolation "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle" \ + || (git clone https://huggingface.co/spaces/openenv/wordle /tmp/wordle && pip install /tmp/wordle && rm -rf /tmp/wordle) + +COPY app.py /app/app.py + +EXPOSE 7860 + +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"] + +WORKDIR /app diff --git a/examples/scripts/openenv/wordle_space/README.md b/examples/scripts/openenv/wordle_space/README.md new file mode 100644 index 00000000000..6857374dff2 --- /dev/null +++ b/examples/scripts/openenv/wordle_space/README.md @@ -0,0 +1,11 @@ +--- +title: Wordle Environment (High Capacity) +emoji: ๐ŸŸฉ +colorFrom: green +colorTo: yellow +sdk: docker +app_port: 7860 +--- + +Wordle environment server with 256 concurrent session support for async GRPO training. +Used by `examples/scripts/openenv/async_wordle.py` in the TRL repo. diff --git a/examples/scripts/openenv/wordle_space/app.py b/examples/scripts/openenv/wordle_space/app.py new file mode 100644 index 00000000000..9e0e251df8d --- /dev/null +++ b/examples/scripts/openenv/wordle_space/app.py @@ -0,0 +1,20 @@ +"""Wordle environment server with higher concurrent session capacity for async GRPO training.""" + +import os + +from openenv.core.env_server.http_server import ConcurrencyConfig, create_app +from textarena_env.models import TextArenaAction, TextArenaObservation +from textarena_env.server.environment import TextArenaEnvironment + + +# Mark TextArena as safe for concurrent sessions (each session gets its own game instance) +TextArenaEnvironment.SUPPORTS_CONCURRENT_SESSIONS = True + +max_sessions = int(os.getenv("MAX_CONCURRENT_SESSIONS", "256")) + +app = create_app( + TextArenaEnvironment, + TextArenaAction, + TextArenaObservation, + concurrency_config=ConcurrencyConfig(max_concurrent_envs=max_sessions), +) diff --git a/pyproject.toml b/pyproject.toml index f4831765478..45f6342690c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.4.0", "datasets>=4.7.0", # Support Json type and on_mixed_types="use_json" + "huggingface-hub>=0.36.2", "packaging>20.0", "transformers>=4.56.2", ] @@ -98,6 +99,15 @@ vlm = [ math_verify = [ "math-verify>=0.5.2", ] +delta_weight_sync = [ + # Delta weight sync for AsyncGRPO (trl/experimental/async_grpo). Needs vLLM's sparse weight + # transfer (vllm-project/vllm#40096): merged to main 2026-06-01, NOT in any release yet (latest + # is v0.22.0). Install vLLM from the nightly index until it ships in a stable release โ€” hence no + # `vllm` pin here (validated against 0.22.1rc1.dev*; expected to first land in ~0.23.0). + # "vllm>=0.23.0" + "huggingface-hub>=1.17.0", + "nvidia-nvcomp-cu12>=5.2.0", # optional: only for the nvcomp_cascaded index encoding (CUDA 12) +] dev = [ # bco "scikit-learn", diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..a0663535da6 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -185,6 +185,40 @@ class AsyncGRPOConfig(_BaseConfig): metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + # Delta weight sync + delta_sync_enabled: bool = field( + default=False, + metadata={ + "help": "Enable delta-compressed weight synchronization. Instead of transferring all " + "weights over NCCL, encode only changed bf16 weights as sparse safetensors patches, " + "upload them to HuggingFace Hub (Xet storage), and signal vLLM to fetch and apply them." + }, + ) + delta_sync_repo_id: str | None = field( + default=None, + metadata={ + "help": "HuggingFace Hub repository for storing delta weight patches and anchor " + "checkpoints (e.g. 'user/training-run-xyz'). Required when delta_sync_enabled=True. " + "The repo is created automatically if it does not exist." + }, + ) + delta_sync_anchor_interval: int = field( + default=10, + metadata={ + "help": "Save a full bf16 anchor checkpoint every N weight sync steps. Between anchors " + "only sparse delta patches are saved. Fireworks blog uses N=25, PULSE paper uses N=50." + }, + ) + # TODO: should probably inherit from the same (str,Enum) for delta sync encoding + delta_sync_encoding: str = field( + default="gap_delta", + metadata={ + "help": "Index encoding for delta patches: 'raw' (int32), 'gap_delta' (uint16 gaps, " + "default), or 'nvcomp_cascaded' (GPU Cascaded delta+bitpack, ~1.3 B/idx, needs " + "nvidia-nvcomp). Values are always stored raw; this only compresses the index half." + }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -201,6 +235,9 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + if self.delta_sync_enabled and self.delta_sync_repo_id is None: + raise ValueError("delta_sync_repo_id is required when delta_sync_enabled=True") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index aca72c73596..903d2be056c 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,6 +35,7 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker +from .weight_diff import LowByteChangeDetector logger = get_logger(__name__) @@ -60,6 +61,8 @@ def stop(self) -> None: ... def pause(self) -> None: ... def resume(self) -> None: ... def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -> None: ... + def upload_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -> None: ... + def apply_weights(self) -> None: ... def update_model_version(self, version: int) -> None: ... @@ -379,6 +382,10 @@ def __init__( weight_names=weight_names, weight_dtype_names=weight_dtype_names, weight_shapes=weight_shapes, + delta_sync_enabled=self.args.delta_sync_enabled, + delta_sync_repo_id=self.args.delta_sync_repo_id, + delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, + delta_sync_encoding=self.args.delta_sync_encoding, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -388,6 +395,9 @@ def __init__( # Add callbacks self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps)) + # ULP change detector for diagnostic logging (delta sync only) + self._change_detector: LowByteChangeDetector | None = None + def get_train_dataloader(self) -> DataLoader: if self.accelerator.is_main_process: dataset = RolloutQueueDataset( @@ -428,6 +438,10 @@ def _set_signature_columns_if_needed(self): ] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # NOTE: Register the change detector before the first optimizer.step so step 1 is captured and the + # first delta sync is already sparse (not a full upload) + self._maybe_init_change_detector() + input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] @@ -562,43 +576,126 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() def _streaming_iter(self): - # Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across - # FSDP ranks, then frees it once the generator advances โ€” avoiding materializing the full model in memory. + """Yield ``(name, tensor, mask)`` tuples. + + - No change detector (NCCL path or first sync): all params, ``mask=None``. + - Change detector active: only changed params with element-level masks. + + The anchor/delta decision is NOT made here โ€” the rollout worker handles that. + """ + # TODO(@aminediro): _validated_masks maybe a property or a getter function because this is weird + if self._change_detector is None or not self._change_detector._validated_masks: + for name, param in self.model.named_parameters(): + name = name.removeprefix("module.") + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None + return + + masks = self._change_detector._validated_masks + total, yielded = 0, 0 for name, param in self.model.named_parameters(): - name = name.removeprefix("module.") # DDP/FSDP1 wrapping - full = param.full_tensor() if isinstance(param, DTensor) else param.detach() - yield name, full + total += 1 + name = name.removeprefix("module.") + mask = masks.get(name) + if mask is None or not mask.any(): + continue + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask + yielded += 1 + logger.info(f"Delta: {yielded}/{total} params changed") + + def _maybe_init_change_detector(self): + """Create the bf16 change detector once the (prepared) optimizer exists. No-op otherwise. + + Called from ``compute_loss`` on the first training step so the optimizer step hooks are + registered *before* the first ``optimizer.step()``. If we waited until the first weight sync + (``on_step_end``), step 1's update would be missed and that sync would fall back to a full + upload instead of a sparse delta. Idempotent; only runs when delta sync is enabled. + """ + if ( + self.args.delta_sync_enabled + and self._change_detector is None + and getattr(self, "optimizer", None) is not None + ): + # Unwrap AcceleratedOptimizer to the native PyTorch optimizer (register_step_*_hook + # requires torch.optim.Optimizer internals). + raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) + self._change_detector = LowByteChangeDetector(self.model, raw_optimizer) def _sync_weight(self): + # Normally already created in compute_loss; this is a fallback (e.g. a sync before any step). + self._maybe_init_change_detector() + + if ( + self.args.delta_sync_enabled + and self._change_detector is not None + and self._change_detector._validated_masks + and self.accelerator.is_main_process + ): + total_changed = 0 + total_elements = 0 + for mask in self._change_detector._validated_masks.values(): + total_changed += mask.sum().item() + total_elements += mask.numel() + sparsity = 1.0 - total_changed / max(total_elements, 1) + self._metrics["train"]["delta/sparsity"].append(sparsity) + self._metrics["train"]["delta/total_changed"].append(total_changed) + self._metrics["train"]["delta/total_elements"].append(total_elements) + logger.info(f"Delta: {total_changed}/{total_elements} elements changed (sparsity={sparsity:.4%})") + t0 = time.time() - logger.info("Weight sync: pausing vLLM...") + is_delta = self.args.delta_sync_enabled + + if is_delta: + # Phase 1: Upload to HF Hub while inference continues + logger.info("Weight sync: uploading to HF Hub (inference still running)...") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.upload_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + self.accelerator.wait_for_everyone() + t_upload = time.time() + logger.info(f"Weight sync: upload took {t_upload - t0:.1f}s, now pausing vLLM...") + + # Phase 2: Pause inference if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.pause() t_pause = time.time() - logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...") self.accelerator.wait_for_everyone() t_barrier = time.time() - logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") - if self.accelerator.is_main_process and self.rollout_worker: - self.rollout_worker.send_weights(self._streaming_iter()) + if is_delta: + # Phase 3: Signal vLLM to fetch the already-uploaded weights + logger.info(f"Weight sync: signaling vLLM to apply... (pause took {t_pause - t_upload:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + try: + self.rollout_worker.apply_weights() + except Exception as e: + logger.warning(f"Weight sync: apply failed ({e}), skipping โ€” vLLM will use stale weights") else: - # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. - for _ in self._streaming_iter(): - pass + # NCCL: transfer all weights directly + logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + t_transfer = time.time() self.accelerator.wait_for_everyone() - logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)") + # Phase 4: Resume + logger.info(f"Weight sync: resuming vLLM... (apply took {t_transfer - t_barrier:.1f}s)") if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.resume() self.model_version += 1 self.rollout_worker.update_model_version(self.model_version) weight_sync_time_s = time.time() - t0 self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") + logger.info( + f"Weight sync: done. Total {weight_sync_time_s:.1f}s (inference paused {t_transfer - t_pause:.1f}s)" + ) def _inner_training_loop(self, *args, **kwargs): # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train() @@ -613,3 +710,6 @@ def _inner_training_loop(self, *args, **kwargs): finally: if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.stop() + if self._change_detector is not None: + self._change_detector.close() + self._change_detector = None diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..75ea49bb9a8 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import copy import inspect import queue import threading @@ -26,12 +27,15 @@ import requests from accelerate.logging import get_logger from datasets import Dataset +from huggingface_hub import create_bucket from transformers import AutoTokenizer from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response from trl.import_utils import is_vllm_available from trl.trainer.utils import print_prompt_completions_sample +from .delta_engine import DeltaWeightTransferEngine + if is_vllm_available(min_version="0.17.1"): from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine @@ -56,6 +60,7 @@ class RolloutGroup: tool_mask: list[list[int]] tool_call_counts: list[int] tool_failure_counts: list[int] + environments: list[object] model_version: int queued_at: float = 0.0 @@ -102,11 +107,16 @@ def __init__( weight_names: list[str] | None = None, weight_dtype_names: list[str] | None = None, weight_shapes: list[list[int]] | None = None, + delta_sync_enabled: bool = False, + delta_sync_repo_id: str | None = None, + delta_sync_anchor_interval: int = 10, + delta_sync_encoding: str = "gap_delta", ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'" ) + self.delta_sync_enabled = delta_sync_enabled self.model_name = model_name self.max_tool_calling_iterations = max_tool_calling_iterations self.dataset = dataset @@ -171,7 +181,11 @@ def __init__( self.model_version = 0 self.session = None - # Wait for the vLLM server and initialize NCCL weight transfer. + self._delta_sync_repo_id = delta_sync_repo_id + self._delta_sync_anchor_interval = delta_sync_anchor_interval + self._delta_sync_encoding = delta_sync_encoding + + # Wait for the vLLM server and initialize weight transfer. self._wait_for_server_ready_sync(timeout_s=self.server_timeout) self._init_weight_transfer() @@ -199,6 +213,18 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: time.sleep(poll_interval_s) def _init_weight_transfer(self) -> None: + if self.delta_sync_enabled: + create_bucket(self._delta_sync_repo_id, exist_ok=True) + self._delta_model_version = 0 + self._delta_pending: dict | None = None + requests.post( + f"{self.vllm_server_url}/init_weight_transfer_engine", + json={"init_info": {}}, + timeout=120, + ) + logger.info("Init delta weight transfer with HF Hub repo %s", self._delta_sync_repo_id) + return + response = requests.get(f"{self.vllm_server_url}/get_world_size") inference_world_size = response.json()["world_size"] world_size = inference_world_size + 1 @@ -287,6 +313,83 @@ def resume(self) -> None: logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") def send_weights(self, iterator) -> None: + """NCCL path: broadcast all params and signal the apply in one call. + + Delta sync uses the explicit two-phase [`upload_weights`] / [`apply_weights`] instead. + """ + self._send_weights_nccl(iterator) + + def upload_weights(self, iterator) -> None: + """Delta phase 1 (inference still running): encode the changed params as a patch, upload it + to the bucket, and record where [`apply_weights`] should fetch it from. + + If nothing changed (empty ``iterator``) the upload is a no-op and ``_delta_pending`` is left + cleared, so the later apply is skipped and vLLM keeps its current weights. The phase is + explicit โ€” never inferred from iterator emptiness โ€” so a zero-change step can't be mistaken + for the apply signal. + """ + self._delta_model_version += 1 + is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0 + if is_anchor: + iterator = ((name, tensor, None) for name, tensor, _mask in iterator) # strip masks -> full tensors + + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" + meta = DeltaWeightTransferEngine.upload( + iterator=iterator, + bucket_id=self._delta_sync_repo_id, + filename=filename, + model_version=self._delta_model_version, + encoding=self._delta_sync_encoding, + ) + self._delta_pending = ( + None + if meta is None + else { + "is_anchor": is_anchor, + "update_info": { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "update_kind": "dense" if is_anchor else "sparse_flat", + }, + } + ) + + def apply_weights(self) -> None: + """Delta phase 3 (inference paused): signal vLLM to fetch + apply the uploaded patch. + + No-op when nothing was uploaded this step. ``_delta_pending`` is cleared up front so a failed + apply leaves no stale state for the next step. + """ + if self._delta_pending is None: + return + p, self._delta_pending = self._delta_pending, None + # Anchors are HF-checkpoint-format full tensors; deltas are sparse kernel-format. + self._post_vllm("/start_weight_update", {"is_checkpoint_format": p["is_anchor"]}) + self._post_vllm("/update_weights", {"update_info": p["update_info"]}, retries=5) + self._post_vllm("/finish_weight_update", {}) + + def _post_vllm(self, path: str, json_body: dict, retries: int = 1, timeout: int = 300) -> None: + """POST to a vLLM server endpoint with bounded retry on 429 / connection errors.""" + url = f"{self.vllm_server_url}{path}" + for attempt in range(retries): + try: + resp = requests.post(url, json=json_body, timeout=timeout) + if resp.status_code < 429: + resp.raise_for_status() + return + status = resp.status_code + except requests.RequestException as e: + logger.warning(f"[weight_sync] POST {path} failed: {e}") + status = "connection error" + if attempt < retries - 1: + wait = min(2**attempt, 30) + logger.warning(f"[weight_sync] POST {path} -> {status}, retry in {wait}s ({attempt + 1}/{retries})") + time.sleep(wait) + raise RuntimeError(f"[weight_sync] POST {path} failed after {retries} attempt(s)") + + def _send_weights_nccl(self, iterator) -> None: + """NCCL sync: broadcast all params via NCCL + signal /update_weights.""" if self.model_update_group is None: return t0 = time.time() @@ -299,7 +402,7 @@ def send_weights(self, iterator) -> None: logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)") t_nccl = time.time() NCCLWeightTransferEngine.trainer_send_weights( - iterator=iterator, + iterator=((name, tensor) for name, tensor, _mask in iterator), trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), ) logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s") @@ -346,6 +449,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: tool_mask=[], tool_call_counts=[], tool_failure_counts=[], + environments=[], model_version=self.model_version, ) pending_completed[group_id] = 0 @@ -353,10 +457,14 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: slot = free_slots.pop() if self.environments is not None: - # Current assumption: reset side effects matter, return value is ignored. - self.environments[slot].reset(**row) - - logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") + try: + self.environments[slot].reset(**row) + except Exception as e: + logger.warning(f"[slot={slot}] env.reset() failed: {e}, skipping generation") + free_slots.add(slot) + continue + + logger.debug(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") task = asyncio.create_task( self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot]) ) @@ -382,7 +490,27 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: free_slots.add(slot) logger.debug(f"[slot] freed slot={slot} group={group_id} free_after={len(free_slots)}") if task.exception() is not None: - raise task.exception() + logger.warning( + f"[slot={slot}] generation failed for group {group_id}: {task.exception()}, skipping" + ) + pending_completed[group_id] += 1 + if pending_completed[group_id] == self.num_generations: + # All generations attempted but some failed โ€” drop the group + group = pending_groups[group_id] + if group.completions: + # Score whatever we have + group.queued_at = time.monotonic() + try: + self._groups_to_score.put_nowait(group) + except asyncio.QueueFull: + pass + logger.warning( + f"Group {group_id} had failures, got {len(pending_groups[group_id].completions)}" + f"/{self.num_generations} completions" + ) + del pending_groups[group_id] + del pending_completed[group_id] + continue ( completion, @@ -399,6 +527,8 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: group.tool_mask.append(tool_mask) group.tool_call_counts.append(tool_call_count) group.tool_failure_counts.append(tool_failure_count) + if self.environments is not None: + group.environments.append(copy.copy(self.environments[slot])) # TODO: move this in generation task, shouldn't matter but is correct self._total_completion_tokens += sum(tool_mask) pending_completed[group_id] += 1 @@ -636,6 +766,8 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]: completion_ids=group.completions_ids, **group.reward_kwargs, ) + if group.environments: + kwargs["environments"] = group.environments all_rewards = await asyncio.gather( *[ reward_func(**kwargs) @@ -678,6 +810,7 @@ async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]: prompt=group.prompt, completion=completion, input_ids=group.prompt_ids + completion_ids, + # FIXME(): normalize completion_mask with GRPO, add tool_mask to the RolloutSample completion_mask=[0] * len(group.prompt_ids) + tool_mask, old_log_probs=[0.0] * len(group.prompt_ids) + logprobs, advantage=advantage, diff --git a/trl/experimental/async_grpo/delta_codec.py b/trl/experimental/async_grpo/delta_codec.py new file mode 100644 index 00000000000..b1330c7ddf0 --- /dev/null +++ b/trl/experimental/async_grpo/delta_codec.py @@ -0,0 +1,169 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU-resident sparse delta extraction + index encoding. + +The change mask, ``nonzero``, and value gather all run on the device, so only the sparse +payload โ€” and, for the disk transport, its compressed index form โ€” crosses PCIe. This +replaces the dense ``tensor.to(bf16).cpu()`` + CPU ``nonzero``/gather in the v1 upload path, +which copied the full dense tensor (~100%) to host just to keep ~1-3% of it. + +Index encodings (lossless, shrink only the index half โ€” values are sent raw): + +- ``raw`` : int32 absolute flat positions (4 B/elem) โ€” what NCCL / vLLM #40096 expect. +- ``gap_delta`` : ``idx[k] - idx[k-1] - 1`` packed to uint16 (2 B), uint32 fallback per param. +- ``nvcomp`` : nvCOMP "Cascaded" (delta + bit-pack [+ RLE]) over the int32 indices, on GPU. +""" + +from __future__ import annotations + +from enum import Enum + +import numpy as np +import torch + + +try: + from nvidia import nvcomp # pip: nvidia-nvcomp + + _NVCOMP_OK = True +except Exception: # pragma: no cover - environment dependent + nvcomp = None + _NVCOMP_OK = False + + +class Encoding(str, Enum): + """Index-encoding scheme for a sparse delta patch (values are always stored raw).""" + + RAW = "raw" # int32 absolute positions (4 B/elem) + GAP_DELTA = "gap_delta" # uint16 gaps, uint32 fallback per param (~2 B/elem) + NVCOMP_CASCADED = "nvcomp_cascaded" # nvCOMP Cascaded delta+bitpack on the GPU (~1.3 B/elem) + + +def extract_sparse( + param: torch.Tensor, + mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pull changed ``(indices, values)`` for one param, entirely on ``param``'s device. + + Args: + param (`torch.Tensor`): + Full dense parameter (kept on the GPU; never copied to host here). + mask (`torch.Tensor`): + Boolean change mask, same numel as ``param``, same device. + + Returns: + `tuple` of: + - indices (`torch.Tensor`): 1D `int32` flat positions (ascending), on device. + - values (`torch.Tensor`): 1D values at those positions, param dtype, on device. + """ + flat = param.detach().reshape(-1) + idx = mask.reshape(-1).nonzero(as_tuple=True)[0] # device; one sync (dynamic size) + vals = flat.index_select(0, idx) + return idx.to(torch.int32), vals + + +def extract_sparse_batched( + items: list[tuple[str, torch.Tensor, torch.Tensor]], +) -> list[tuple[str, torch.Tensor, torch.Tensor]]: + """Batched [`extract_sparse`] over many params with a single ``nonzero``. + + ``nonzero`` forces a deviceโ†’host sync (its output size is data-dependent), so the per-param + loop costs one sync per param. This concatenates all flattened masks, runs **one** ``nonzero`` + over the whole set, then splits the global positions back per param with ``searchsorted`` on the + cumulative sizes โ€” collapsing ~N syncs into 2 (the ``nonzero`` and one boundary D2H). Indices are + returned local to each param's flat space (ready for that param's ``index_copy_``). + + All tensors must be on the same device. (Transient cost: a concatenated full-size bool mask โ€” + fine for the model sizes here; very large models should shard, see the multi-file TODO.) + + Args: + items (`list[tuple[str, torch.Tensor, torch.Tensor]]`): + ``(name, tensor, mask)`` triples; ``mask`` is the per-param boolean change mask. + + Returns: + `list[tuple[str, torch.Tensor, torch.Tensor]]`: ``(name, int32 local indices, values)`` per + input param, in input order. + """ + if not items: + return [] + device = items[0][1].device + flats = [tensor.detach().reshape(-1) for _, tensor, _ in items] + sizes = torch.tensor([f.numel() for f in flats], device=device) + offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, 0)]) + global_idx = torch.cat([mask.reshape(-1) for _, _, mask in items]).nonzero(as_tuple=True)[0] + bounds = torch.searchsorted(global_idx, offsets[1:]).tolist() + + out = [] + prev = 0 + for i, (name, _, _) in enumerate(items): + g = global_idx[prev : bounds[i]] - offsets[i] # local positions within this param + out.append((name, g.to(torch.int32), flats[i].index_select(0, g))) + prev = bounds[i] + return out + + +def gap_delta_encode(idx: torch.Tensor) -> torch.Tensor: + """Gap-encode sorted positions: ``delta[k] = idx[k] - idx[k-1] - 1`` (idx[-1] := -1). + + Returns the gaps as ``uint16`` if the max gap fits, else ``uint32`` โ€” so one outlier never + bumps the whole param to 4 B. The dtype *is* the width (no separate width needed); the + receiver inverts with [`gap_delta_decode`]. + + Args: + idx (`torch.Tensor`): + 1D ascending `int32` positions (as returned by [`extract_sparse`]). + + Returns: + `torch.Tensor`: 1D gaps, dtype `uint16` or `uint32`, same device. + """ + if idx.numel() == 0: + return idx.to(torch.uint16) + idx64 = idx.long() + prev = torch.cat([idx64.new_full((1,), -1), idx64[:-1]]) + deltas = idx64 - prev - 1 + return deltas.to(torch.uint16 if int(deltas.max()) <= 0xFFFF else torch.uint32) + + +def gap_delta_decode(deltas: torch.Tensor) -> torch.Tensor: + """Invert [`gap_delta_encode`] โ†’ 1D ascending `int32` positions on ``deltas``' device.""" + if deltas.numel() == 0: + return deltas.to(torch.int32) + return (torch.cumsum(deltas.long() + 1, dim=0) - 1).to(torch.int32) + + +def nvcomp_available() -> bool: + return _NVCOMP_OK + + +def nvcomp_encode(idx: torch.Tensor) -> torch.Tensor: + """Compress absolute int32 indices with nvCOMP Cascaded โ†’ ``uint8`` byte tensor (CPU). + + Cascaded does the delta + bit-pack internally, so callers pass raw int32 positions (no + gap-encoding needed). The self-describing bitstream carries dtype + length, so + [`nvcomp_decode`] needs nothing else. Cascaded is a GPU codec, so indices are moved to CUDA. + """ + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + comp = nvcomp.Codec(algorithm="Cascaded").encode(nvcomp.as_array(idx.to("cuda", torch.int32).contiguous())) + return torch.from_numpy(np.asarray(comp.cpu()).view(np.uint8).copy()) + + +def nvcomp_decode(raw: torch.Tensor) -> torch.Tensor: + """Inverse of [`nvcomp_encode`]: ``uint8`` bytes โ†’ 1D ``int32`` indices (CPU).""" + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + dec = nvcomp.Codec(algorithm="Cascaded").decode(nvcomp.as_array(raw.to("cuda").contiguous())) + return torch.from_numpy(np.asarray(dec.cpu()).view(np.int32).copy()) # reinterpret bytes โ†’ int32 diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py new file mode 100644 index 00000000000..adb572749f9 --- /dev/null +++ b/trl/experimental/async_grpo/delta_engine.py @@ -0,0 +1,299 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import tempfile +import time +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +from huggingface_hub import batch_bucket_files, download_bucket_files +from safetensors import safe_open +from safetensors.torch import save_file +from vllm.distributed.weight_transfer.base import ( + SparseWeightPatch, + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +from .delta_codec import ( + Encoding, + extract_sparse_batched, + gap_delta_decode, + gap_delta_encode, + nvcomp_decode, + nvcomp_encode, +) +from .weight_diff import PatchMetadata + + +try: + from vllm.logger import init_logger + + logger = init_logger(f"vllm.{__name__}") +except Exception: # pragma: no cover - vLLM always present where this module is used + logger = logging.getLogger(__name__) + + +@contextmanager +def _fetch(update_info): + """Download a patch from the bucket to a temp file and yield an open safetensors handle.""" + # TODO(@aminediro): writes only 1 safetensors, for very large model, this needs to be multiple files + with tempfile.TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/weights.safetensors" + download_bucket_files(update_info.repo_id, files=[(update_info.filename, path)]) + with safe_open(path, framework="pt", device="cpu") as f: + yield f + + +def _encode_idx(idx: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Encode absolute int32 indices for storage. Returns a CPU tensor (I32, U16/U32, or U8 bytes).""" + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return idx.to(torch.int32).cpu().contiguous() + if encoding is Encoding.GAP_DELTA: + return gap_delta_encode(idx).cpu().contiguous() # native uint16/uint32 (dtype = width) + return nvcomp_encode(idx) # Encoding.NVCOMP_CASCADED -> uint8 CPU bytes + + +def _decode_idx(raw: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Inverse of [`_encode_idx`] โ†’ 1D int32 absolute indices. + + Self-describing: ``raw.dtype`` carries the gap-delta width, so no element count is needed. + """ + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return raw.to(torch.int32) + if encoding is Encoding.GAP_DELTA: + return gap_delta_decode(raw) + return nvcomp_decode(raw) + + +@dataclass +class DeltaWeightTransferInitInfo(WeightTransferInitInfo): + pass + + +@dataclass +class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Per-sync info sent via ``/update_weights`` โ€” just bucket coordinates + kind. + + Names and per-param nnz are read from the downloaded file, so ``num_updates_list`` is not + required here (we override the base validation that would otherwise demand it for sparse). + """ + + repo_id: str = "" # bucket_id + filename: str = "" + + def __post_init__(self) -> None: + if self.update_kind not in ("dense", "sparse_flat"): + raise ValueError(f"Unsupported update_kind: {self.update_kind}") + + +class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): + """Weight transfer engine using an HF Storage Bucket as the data plane.""" + + init_info_cls = DeltaWeightTransferInitInfo + update_info_cls = DeltaWeightTransferUpdateInfo + + def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None: + pass + + def receive_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Anchor path: download full safetensors from the bucket and load them directly.""" + t0 = time.time() + with _fetch(update_info) as f: + t_dl = time.time() + for name in f.keys(): + load_weights([(name, f.get_tensor(name))]) + meta = PatchMetadata.from_metadata_dict(f.metadata()) + t_apply = time.time() + logger.info( + "Applied anchor (step %d, %d params) | download %.2fs load %.2fs", + meta.model_version, + meta.num_changed_params, + t_dl - t0, + t_apply - t_dl, + ) + + def receive_sparse_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + apply_patches: Callable[[list[SparseWeightPatch]], None], + ) -> None: + t0 = time.time() + patches = [] + with _fetch(update_info) as f: + t_dl = time.time() + names, idxs, vals = [], [], [] + meta = PatchMetadata.from_metadata_dict(f.metadata()) + for name, idx, values in iter_sparse_patches(f): + names.append(name) + idxs.append(idx) + vals.append(values) + if names: + device = torch.accelerator.current_device_index() + sizes = [i.numel() for i in idxs] + all_idx = torch.cat(idxs).to(device) + all_val = torch.cat(vals).to(device) + off = 0 + for name, n in zip(names, sizes, strict=False): + patches.append( + SparseWeightPatch(name=name, indices=all_idx[off : off + n], values=all_val[off : off + n]) + ) + off += n + if patches: + apply_patches(patches) + + t_apply = time.time() + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f) | download %.2fs decode+apply %.2fs", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + t_dl - t0, + t_apply - t_dl, + ) + + def shutdown(self) -> None: + pass + + @staticmethod + def trainer_send_weights(iterator, trainer_args) -> None: + raise NotImplementedError("Use AsyncRolloutWorker.upload_weights / apply_weights instead") + + @staticmethod + def upload( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + bucket_id: str, + filename: str, + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, + ) -> PatchMetadata | None: + """Encode params as a safetensors patch and push to the bucket. + + Returns the :class:`PatchMetadata` (also written to the safetensors header), or ``None`` + if the iterator was empty. + """ + tensors, meta = encode_patch(iterator, model_version=model_version, encoding=encoding) + if tensors is None: + return None + # Write to a temp file and upload the path: hf-xet's in-memory `upload_bytes` panics on + # multi-GB buffers (e.g. a full-model anchor); the file path uses the large-file code path. + with tempfile.TemporaryDirectory() as tmpdir: + local_path = f"{tmpdir}/patch.safetensors" + save_file(tensors, local_path, metadata=meta.to_metadata_dict()) + size_mb = os.path.getsize(local_path) / 1e6 + batch_bucket_files(bucket_id, add=[(local_path, filename)]) + logger.info( + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, enc=%s, sparsity=%.4f)", + bucket_id, + filename, + size_mb, + meta.num_changed_params, + meta.sparse, + meta.encoding.value, + meta.sparsity, + ) + return meta + + +def encode_patch( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, +) -> tuple[dict[str, torch.Tensor] | None, PatchMetadata | None]: + """Build the safetensors tensor dict + metadata for a patch (no I/O). + + Each item is ``(name, tensor, mask)``: + + - ``mask is None``: full tensor stored as ``name`` (anchor). + - ``mask`` provided: GPU sparse-extract; store encoded indices as ``{name}.idx`` and values + as ``{name}.val``. ``encoding`` is ``"raw"`` (int32) or ``"gap_delta"`` (uint16 gap bytes, + uint32 fallback per param). + """ + encoding = Encoding(encoding) + tensors: dict[str, torch.Tensor] = {} + delta_items: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + n_params = 0 + total_changed = 0 + total_elements = 0 + + for name, tensor, mask in iterator: + n_params += 1 + total_elements += tensor.numel() + if mask is None: # anchor: store the full tensor + tensors[name] = tensor.detach().to(torch.bfloat16).cpu().contiguous().clone() + total_changed += tensor.numel() + else: + delta_items.append((name, tensor, mask)) + + for name, idx, vals in extract_sparse_batched(delta_items): + total_changed += idx.numel() + tensors[f"{name}.val"] = vals.to(torch.bfloat16).cpu().contiguous() + tensors[f"{name}.idx"] = _encode_idx(idx, encoding) + + sparse = bool(delta_items) + if not tensors: + return None, None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=n_params, + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + encoding=encoding, + ) + return tensors, meta + + +def iter_sparse_patches(f) -> Iterator[tuple[str, torch.Tensor, torch.Tensor]]: + """Yield ``(name, int32 indices, values)`` from an open sparse safetensors handle. + + Flat, self-describing format: param names are recovered from the ``{name}.val`` tensor keys and + the index encoding is a single global header field (the gap-delta width is carried by the index + tensor's own dtype). ``f`` is a ``safetensors.safe_open`` handle. + """ + encoding = PatchMetadata.from_metadata_dict(f.metadata()).encoding + names = sorted({k[: -len(".val")] for k in f.keys() if k.endswith(".val")}) + for name in names: + idx = _decode_idx(f.get_tensor(f"{name}.idx"), encoding) + yield name, idx, f.get_tensor(f"{name}.val") + + +class DeltaWorkerExtension: + """vLLM worker-extension hook (pass via ``--worker-extension-cls``). + + Required: ``--worker-extension-cls`` makes the vLLM *worker* process import this module, which + runs the ``register_engine`` call below so the ``"delta"`` backend exists in the worker (the + factory registry is per-process)""" + + pass + + +if "delta" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..8d7f8fe3c06 --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,263 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization utilities. + +- ``BF16ChangeDetector``: hooks into the optimizer to detect which bf16 elements + actually changed after each step. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both + anchor (full) and delta (sparse) weight files. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass +from enum import Enum + +import torch + +from .delta_codec import Encoding + + +logger = logging.getLogger(__name__) + + +class BF16ChangeDetector: + """Detects which bf16 weights actually changed across an optimizer step. + + Hooks into the optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` + (PyTorch >= 2.1). Snapshots bf16 values before the step, compares after. + + ``_validated_masks[name]`` is a boolean tensor with True for each element that changed. + """ + + def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "BF16ChangeDetector: matched %d/%d optimizer params", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + ) + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._pre_step_bf16.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None or name not in self._pre_step_bf16: + continue + self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name] + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +def _low_byte(p: torch.Tensor, device: torch.device | str | None = None) -> torch.Tensor: + """Low byte of each element's bf16 bit pattern as ``uint8`` (1 byte/elem). + + bf16 layout (16 bits): ``[sign | exp(8) | mantissa(7)]``. The low byte holds all 7 + mantissa bits plus the exponent LSB, so any sub-ULP / mantissa-level update flips it. + Snapshotting just this byte costs 1 B/elem โ€” half a full bf16 clone. + + Stays on ``p``'s device by default (so the change mask is computed on-GPU and only the + sparse payload crosses PCIe). Pass ``device="cpu"`` to keep the snapshot in host memory + when a full on-device snapshot would not fit (e.g. DeepSpeed-Z2 holds full params/rank). + """ + bf16 = p.detach().to(torch.bfloat16) + if device is not None: + bf16 = bf16.to(device) + bf16 = bf16.contiguous() + # TODO: 0xFF should be configurable either at module level or via some number of precision bits. + return bf16.view(torch.int16).bitwise_and(0xFF).to(torch.uint8) + + +class LowByteChangeDetector: + """Detects changed bf16 weights from a 1-byte-per-element snapshot. + + Like [`BF16ChangeDetector`], hooks the optimizer (PyTorch >= 2.1) and diffs pre/post + step โ€” but snapshots only the **low byte** of each weight's bf16 pattern (1 B/elem) + instead of the full bf16 value (2 B/elem). A flipped low byte implies the bf16 value + changed, so the detected mask is a strict subset of the true change set: **no false + positives**, but rare false negatives (mantissa + exp-LSB unchanged while a high + exponent/sign bit changed). Those misses cause inference-side drift, bounded by + periodic anchors โ€” set ``validate_recall=True`` to measure the miss rate. + + ``_validated_masks[name]`` is a boolean tensor, True for each element detected as changed. + + Args: + model ([`~torch.nn.Module`]): + Model whose parameters are tracked. + optimizer ([`~torch.optim.Optimizer`]): + Optimizer to hook. Must expose native ``register_step_*_hook`` (unwrap Accelerate first). + validate_recall (`bool`, *optional*, defaults to `False`): + Also keep a full bf16 snapshot to score low-byte detection against the true diff. + Doubles the snapshot cost; for diagnostics only. + snapshot_to_cpu (`bool`, *optional*, defaults to `False`): + Keep the low-byte snapshot in host memory instead of on the param's device. Masks + are then produced on CPU. Use when a full on-device snapshot would not fit (e.g. + DeepSpeed-Z2 holds full params per rank). Default keeps it on-GPU so the change + mask and sparse extraction stay on the device. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + validate_recall: bool = False, + snapshot_to_cpu: bool = False, + ): + self.validate_recall = validate_recall + self._snap_device = "cpu" if snapshot_to_cpu else None + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_low: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} # only populated when validate_recall + self._accuracy: dict[str, float] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "LowByteChangeDetector: matched %d/%d optimizer params (validate_recall=%s)", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + validate_recall, + ) + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._pre_step_low.clear() + self._pre_step_bf16.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + self._pre_step_low[name] = _low_byte(p, self._snap_device) + if self.validate_recall: + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).to(self._snap_device or p.device).clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + total_tp, total_true, total_elements = 0, 0, 0 + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None or name not in self._pre_step_low: + continue + detected = _low_byte(p, self._snap_device) != self._pre_step_low[name] + self._validated_masks[name] = detected + if self.validate_recall: + # True bf16 diff, computed here while p still holds the post-step value. + post_bf16 = p.detach().to(torch.bfloat16).to(self._snap_device or p.device) + true_mask = post_bf16 != self._pre_step_bf16[name] + total_tp += (detected & true_mask).sum().item() + total_true += true_mask.sum().item() + total_elements += true_mask.numel() + if self.validate_recall: + # Low-byte changes โІ bf16 changes, so precision is 1.0 by construction; + # recall = fraction of truly-changed elements the low byte detected. + self._accuracy = { + "recall": total_tp / max(total_true, 1), + "true_changed": total_true, + "detected_changed": total_tp, + "total_elements": total_elements, + "sparsity": 1.0 - total_true / max(total_elements, 1), + } + + def get_prediction_accuracy(self) -> dict[str, float]: + return dict(self._accuracy) + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +@dataclass +class PatchMetadata: + format: str = "sparse_weight_patch" + version: str = "1" + sparse: bool = False + model_version: int = 0 + num_changed_params: int = 0 + total_changed_elements: int = 0 + total_elements: int = 0 + sparsity: float = 0.0 + encoding: Encoding = Encoding.GAP_DELTA # index encoding (delta files only) + + def to_metadata_dict(self) -> dict[str, str]: + return {k: (v.value if isinstance(v, Enum) else str(v)) for k, v in asdict(self).items()} + + @classmethod + def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + for k, v in d.items(): + if k not in field_types: + continue + ft = field_types[k] + if ft == "int": + kwargs[k] = int(v) + elif ft == "float": + kwargs[k] = float(v) + elif ft == "bool": + kwargs[k] = v == "True" + elif ft == "Encoding": + kwargs[k] = Encoding(v) + else: + kwargs[k] = v + return cls(**kwargs)