From 081a21a7092169076dcfa21bb2fd30b533b675d0 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 4 Jun 2026 09:51:24 +0000 Subject: [PATCH 1/5] Add delta weight sync (sparse bucket patches) for AsyncGRPO --- pyproject.toml | 8 + .../async_grpo/async_grpo_config.py | 29 ++ .../async_grpo/async_grpo_trainer.py | 58 +++- trl/experimental/async_grpo/delta_codec.py | 167 ++++++++++ trl/experimental/async_grpo/delta_engine.py | 298 ++++++++++++++++++ trl/experimental/async_grpo/weight_diff.py | 260 +++++++++++++++ .../async_grpo/weight_transfer.py | 89 ++++++ 7 files changed, 908 insertions(+), 1 deletion(-) create mode 100644 trl/experimental/async_grpo/delta_codec.py create mode 100644 trl/experimental/async_grpo/delta_engine.py create mode 100644 trl/experimental/async_grpo/weight_diff.py diff --git a/pyproject.toml b/pyproject.toml index 279ccce42c1..e7427e5a258 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,14 @@ vlm = [ math_verify = [ "math-verify>=0.5.2", ] +delta_weight_sync = [ + # Delta weight sync for AsyncGRPO (trl/experimental/async_grpo). Needs vLLM's sparse weight + # transfer (vllm-project/vllm#40096): merged to main after v0.22.0, not in a release yet — + # install vLLM from the nightly index until it ships (so no `vllm` pin here). Serve with + # `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`. + "huggingface-hub>=1.17.0", # HF Storage Bucket API (create/batch/download_bucket_files) + "nvidia-nvcomp-cu12>=5.2.0", # optional: only for the nvcomp_cascaded index encoding (CUDA 12) +] openreward = [ "openreward>=0.1.109; python_version >= '3.11'", # openreward requires Python 3.11+ ] diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index d2b588a9033..c8b9764998e 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -195,6 +195,35 @@ class AsyncGRPOConfig(_BaseConfig): }, ) + # Parameters that control delta weight sync (Transport B: sparse patches over an HF Storage Bucket) + delta_sync_enabled: bool = field( + default=False, + metadata={ + "help": "Sync only the changed bf16 weights as sparse safetensors patches via an HF Storage Bucket " + "(applied in place on vLLM), instead of broadcasting the full policy over NCCL. Requires a vLLM with " + "sparse weight transfer (#40096), served with `--model-impl transformers` and `VLLM_USE_V2_MODEL_RUNNER=0`." + }, + ) + delta_sync_repo_id: str | None = field( + default=None, + metadata={ + "help": "HF Storage Bucket for the delta patches/anchors (created if missing). Required when " + "`delta_sync_enabled=True`." + }, + ) + delta_sync_anchor_interval: int = field( + default=10, + metadata={ + "help": "Send a full anchor every N syncs; sparse deltas in between (bounds drift from missed bits)." + }, + ) + delta_sync_encoding: str = field( + default="gap_delta", + metadata={ + "help": "Index encoding for delta patches: 'raw', 'gap_delta', or 'nvcomp_cascaded' (needs nvcomp)." + }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 78a008a7ff8..31647a6294e 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,6 +35,7 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker +from .weight_diff import LowByteChangeDetector from .weight_transfer import WeightTransferClient @@ -413,6 +414,7 @@ def __init__( self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} self._train_tokens_start_time = None self.model_version = 0 + self._change_detector: LowByteChangeDetector | None = None # delta sync only; created in compute_loss # Create worker and queue on rank 0 if self.accelerator.is_main_process: if self.train_dataset is None: @@ -444,6 +446,10 @@ def __init__( "packed": True, "is_checkpoint_format": True, }, + delta_sync_enabled=self.args.delta_sync_enabled, + delta_sync_repo_id=self.args.delta_sync_repo_id, + delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, + delta_sync_encoding=self.args.delta_sync_encoding, ) self.rollout_worker = AsyncRolloutWorker( model_name=model_name, @@ -517,6 +523,10 @@ def _set_signature_columns_if_needed(self): ] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Register the change detector before the first optimizer.step so step 1 is captured and the + # first delta sync is already sparse (no-op unless delta sync is enabled). + self._maybe_init_change_detector() + input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] @@ -667,8 +677,48 @@ def _streaming_iter(self): full = full.to(device) yield name, full + def _streaming_iter_delta(self): + """Like [`_streaming_iter`] but yields ``(name, full, mask)`` for delta sync. With an active + change detector, only changed params are yielded (their element-level masks); otherwise all params with + ``mask=None`` (the cold anchor). All ranks must walk this identically so the FSDP2 ``full_tensor()`` + collectives line up. + """ + device = self.accelerator.device + masks = self._change_detector._validated_masks if self._change_detector is not None else {} + for name, param in self.model.named_parameters(): + name = name.removeprefix("module.") # DDP/FSDP1 wrapping + mask = masks.get(name) if masks else None + if masks and (mask is None or not mask.any()): + continue # unchanged param -> not in this delta + full = param.full_tensor() if isinstance(param, DTensor) else param.detach() + if full.device != device: + full = full.to(device) + yield name, full, mask + + def _maybe_init_change_detector(self): + """Create the bf16 change detector once the (prepared) optimizer exists, before its first + step, so the first delta sync is already sparse. No-op unless delta sync is enabled.""" + if self.args.delta_sync_enabled and self._change_detector is None and getattr(self, "optimizer", None): + # Unwrap AcceleratedOptimizer to the native torch optimizer (register_step_*_hook). + raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) + self._change_detector = LowByteChangeDetector(self.model, raw_optimizer) + def _sync_weight(self): t0 = time.time() + is_delta = self.args.delta_sync_enabled + + if is_delta: + # Delta phase 1: upload the sparse patch to the bucket while inference keeps running. + logger.info("Weight sync: uploading delta patch (inference still running)...") + if self.accelerator.is_main_process and self.weight_transfer: + self.weight_transfer.upload_weights(self._streaming_iter_delta()) + else: + # Non-rank-0 still walks the iterator so full_tensor() collectives complete. + for _ in self._streaming_iter_delta(): + pass + self.accelerator.wait_for_everyone() + + # Pause vllm both delta and full logger.info("Weight sync: pausing vLLM...") if self.accelerator.is_main_process and self.weight_transfer: self.weight_transfer.pause() @@ -679,7 +729,13 @@ def _sync_weight(self): t_barrier = time.time() logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") - if self.accelerator.is_main_process and self.weight_transfer: + if is_delta: + if self.accelerator.is_main_process and self.weight_transfer: + try: + self.weight_transfer.apply_weights_delta() + except Exception as e: + logger.warning(f"Weight sync: apply failed ({e}), skipping, vLLM will use stale weights") + elif self.accelerator.is_main_process and self.weight_transfer: self.weight_transfer.send_weights(self._streaming_iter()) else: # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. diff --git a/trl/experimental/async_grpo/delta_codec.py b/trl/experimental/async_grpo/delta_codec.py new file mode 100644 index 00000000000..62b7f930405 --- /dev/null +++ b/trl/experimental/async_grpo/delta_codec.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU-resident sparse delta extraction + index encoding. + +The change mask, ``nonzero``, and value gather all run on the device, so only the sparse payload — and, for the disk +transport, its compressed index form — crosses PCIe. This replaces the dense ``tensor.to(bf16).cpu()`` + CPU +``nonzero``/gather in the v1 upload path, which copied the full dense tensor (~100%) to host just to keep ~1-3% of it. + +Index encodings (lossless, shrink only the index half — values are sent raw): + +- ``raw`` : int32 absolute flat positions (4 B/elem) — what NCCL / vLLM #40096 expect. +- ``gap_delta`` : ``idx[k] - idx[k-1] - 1`` packed to uint16 (2 B), uint32 fallback per param. +- ``nvcomp`` : nvCOMP "Cascaded" (delta + bit-pack [+ RLE]) over the int32 indices, on GPU. +""" + +from __future__ import annotations + +from enum import Enum + +import numpy as np +import torch + + +try: + from nvidia import nvcomp # pip: nvidia-nvcomp + + _NVCOMP_OK = True +except Exception: # pragma: no cover - environment dependent + nvcomp = None + _NVCOMP_OK = False + + +class Encoding(str, Enum): + """Index-encoding scheme for a sparse delta patch (values are always stored raw).""" + + RAW = "raw" # int32 absolute positions (4 B/elem) + GAP_DELTA = "gap_delta" # uint16 gaps, uint32 fallback per param (~2 B/elem) + NVCOMP_CASCADED = "nvcomp_cascaded" # nvCOMP Cascaded delta+bitpack on the GPU (~1.3 B/elem) + + +def extract_sparse( + param: torch.Tensor, + mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pull changed ``(indices, values)`` for one param, entirely on ``param``'s device. + + Args: + param (`torch.Tensor`): + Full dense parameter (kept on the GPU; never copied to host here). + mask (`torch.Tensor`): + Boolean change mask, same numel as ``param``, same device. + + Returns: + `tuple` of: + - indices (`torch.Tensor`): 1D `int32` flat positions (ascending), on device. + - values (`torch.Tensor`): 1D values at those positions, param dtype, on device. + """ + flat = param.detach().reshape(-1) + idx = mask.reshape(-1).nonzero(as_tuple=True)[0] # device; one sync (dynamic size) + vals = flat.index_select(0, idx) + return idx.to(torch.int32), vals + + +def extract_sparse_batched( + items: list[tuple[str, torch.Tensor, torch.Tensor]], +) -> list[tuple[str, torch.Tensor, torch.Tensor]]: + """Batched [`extract_sparse`] over many params with a single ``nonzero``. + + ``nonzero`` forces a device→host sync (its output size is data-dependent), so the per-param loop costs one sync per + param. This concatenates all flattened masks, runs **one** ``nonzero`` over the whole set, then splits the global + positions back per param with ``searchsorted`` on the cumulative sizes — collapsing ~N syncs into 2 (the + ``nonzero`` and one boundary D2H). Indices are returned local to each param's flat space (ready for that param's + ``index_copy_``). + + All tensors must be on the same device. (Transient cost: a concatenated full-size bool mask — fine for the model + sizes here; very large models should shard, see the multi-file TODO.) + + Args: + items (`list[tuple[str, torch.Tensor, torch.Tensor]]`): + ``(name, tensor, mask)`` triples; ``mask`` is the per-param boolean change mask. + + Returns: + `list[tuple[str, torch.Tensor, torch.Tensor]]`: ``(name, int32 local indices, values)`` per input param, in + input order. + """ + if not items: + return [] + device = items[0][1].device + flats = [tensor.detach().reshape(-1) for _, tensor, _ in items] + sizes = torch.tensor([f.numel() for f in flats], device=device) + offsets = torch.cat([sizes.new_zeros(1), torch.cumsum(sizes, 0)]) + global_idx = torch.cat([mask.reshape(-1) for _, _, mask in items]).nonzero(as_tuple=True)[0] + bounds = torch.searchsorted(global_idx, offsets[1:]).tolist() + + out = [] + prev = 0 + for i, (name, _, _) in enumerate(items): + g = global_idx[prev : bounds[i]] - offsets[i] # local positions within this param + out.append((name, g.to(torch.int32), flats[i].index_select(0, g))) + prev = bounds[i] + return out + + +def gap_delta_encode(idx: torch.Tensor) -> torch.Tensor: + """Gap-encode sorted positions: ``delta[k] = idx[k] - idx[k-1] - 1`` (idx[-1] := -1). + + Returns the gaps as ``uint16`` if the max gap fits, else ``uint32`` — so one outlier never bumps the whole param to + 4 B. The dtype *is* the width (no separate width needed); the receiver inverts with [`gap_delta_decode`]. + + Args: + idx (`torch.Tensor`): + 1D ascending `int32` positions (as returned by [`extract_sparse`]). + + Returns: + `torch.Tensor`: 1D gaps, dtype `uint16` or `uint32`, same device. + """ + if idx.numel() == 0: + return idx.to(torch.uint16) + idx64 = idx.long() + prev = torch.cat([idx64.new_full((1,), -1), idx64[:-1]]) + deltas = idx64 - prev - 1 + return deltas.to(torch.uint16 if int(deltas.max()) <= 0xFFFF else torch.uint32) + + +def gap_delta_decode(deltas: torch.Tensor) -> torch.Tensor: + """Invert [`gap_delta_encode`] → 1D ascending `int32` positions on ``deltas``' device.""" + if deltas.numel() == 0: + return deltas.to(torch.int32) + return (torch.cumsum(deltas.long() + 1, dim=0) - 1).to(torch.int32) + + +def nvcomp_available() -> bool: + return _NVCOMP_OK + + +def nvcomp_encode(idx: torch.Tensor) -> torch.Tensor: + """Compress absolute int32 indices with nvCOMP Cascaded → ``uint8`` byte tensor (CPU). + + Cascaded does the delta + bit-pack internally, so callers pass raw int32 positions (no gap-encoding needed). The + self-describing bitstream carries dtype + length, so [`nvcomp_decode`] needs nothing else. Cascaded is a GPU codec, + so indices are moved to CUDA. + """ + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + comp = nvcomp.Codec(algorithm="Cascaded").encode(nvcomp.as_array(idx.to("cuda", torch.int32).contiguous())) + return torch.from_numpy(np.asarray(comp.cpu()).view(np.uint8).copy()) + + +def nvcomp_decode(raw: torch.Tensor) -> torch.Tensor: + """Inverse of [`nvcomp_encode`]: ``uint8`` bytes → 1D ``int32`` indices (CPU).""" + if not _NVCOMP_OK: + raise RuntimeError("nvidia-nvcomp not installed") + dec = nvcomp.Codec(algorithm="Cascaded").decode(nvcomp.as_array(raw.to("cuda").contiguous())) + return torch.from_numpy(np.asarray(dec.cpu()).view(np.int32).copy()) # reinterpret bytes → int32 diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py new file mode 100644 index 00000000000..678e86dbbb2 --- /dev/null +++ b/trl/experimental/async_grpo/delta_engine.py @@ -0,0 +1,298 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import tempfile +import time +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +from huggingface_hub import batch_bucket_files, download_bucket_files +from safetensors import safe_open +from safetensors.torch import save_file +from vllm.distributed.weight_transfer.base import ( + SparseWeightPatch, + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +from .delta_codec import ( + Encoding, + extract_sparse_batched, + gap_delta_decode, + gap_delta_encode, + nvcomp_decode, + nvcomp_encode, +) +from .weight_diff import PatchMetadata + + +try: + from vllm.logger import init_logger + + logger = init_logger(f"vllm.{__name__}") +except Exception: # pragma: no cover - vLLM always present where this module is used + logger = logging.getLogger(__name__) + + +@contextmanager +def _fetch(update_info): + """Download a patch from the bucket to a temp file and yield an open safetensors handle.""" + # TODO(@aminediro): writes only 1 safetensors, for very large model, this needs to be multiple files + with tempfile.TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/weights.safetensors" + download_bucket_files(update_info.repo_id, files=[(update_info.filename, path)]) + with safe_open(path, framework="pt", device="cpu") as f: + yield f + + +def _encode_idx(idx: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Encode absolute int32 indices for storage. Returns a CPU tensor (I32, U16/U32, or U8 bytes).""" + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return idx.to(torch.int32).cpu().contiguous() + if encoding is Encoding.GAP_DELTA: + return gap_delta_encode(idx).cpu().contiguous() # native uint16/uint32 (dtype = width) + return nvcomp_encode(idx) # Encoding.NVCOMP_CASCADED -> uint8 CPU bytes + + +def _decode_idx(raw: torch.Tensor, encoding: Encoding) -> torch.Tensor: + """Inverse of [`_encode_idx`] → 1D int32 absolute indices. + + Self-describing: ``raw.dtype`` carries the gap-delta width, so no element count is needed. + """ + encoding = Encoding(encoding) + if encoding is Encoding.RAW: + return raw.to(torch.int32) + if encoding is Encoding.GAP_DELTA: + return gap_delta_decode(raw) + return nvcomp_decode(raw) + + +@dataclass +class DeltaWeightTransferInitInfo(WeightTransferInitInfo): + pass + + +@dataclass +class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Per-sync info sent via ``/update_weights`` — just bucket coordinates + kind. + + Names and per-param nnz are read from the downloaded file, so ``num_updates_list`` is not required here (we + override the base validation that would otherwise demand it for sparse). + """ + + repo_id: str = "" # bucket_id + filename: str = "" + + def __post_init__(self) -> None: + if self.update_kind not in ("dense", "sparse_flat"): + raise ValueError(f"Unsupported update_kind: {self.update_kind}") + + +class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): + """Weight transfer engine using an HF Storage Bucket as the data plane.""" + + init_info_cls = DeltaWeightTransferInitInfo + update_info_cls = DeltaWeightTransferUpdateInfo + + def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None: + pass + + def receive_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Anchor path: download full safetensors from the bucket and load them directly.""" + t0 = time.time() + with _fetch(update_info) as f: + t_dl = time.time() + for name in f.keys(): + load_weights([(name, f.get_tensor(name))]) + meta = PatchMetadata.from_metadata_dict(f.metadata()) + t_apply = time.time() + logger.info( + "Applied anchor (step %d, %d params) | download %.2fs load %.2fs", + meta.model_version, + meta.num_changed_params, + t_dl - t0, + t_apply - t_dl, + ) + + def receive_sparse_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + apply_patches: Callable[[list[SparseWeightPatch]], None], + ) -> None: + t0 = time.time() + patches = [] + with _fetch(update_info) as f: + t_dl = time.time() + names, idxs, vals = [], [], [] + meta = PatchMetadata.from_metadata_dict(f.metadata()) + for name, idx, values in iter_sparse_patches(f): + names.append(name) + idxs.append(idx) + vals.append(values) + if names: + device = torch.accelerator.current_device_index() + sizes = [i.numel() for i in idxs] + all_idx = torch.cat(idxs).to(device) + all_val = torch.cat(vals).to(device) + off = 0 + for name, n in zip(names, sizes, strict=False): + patches.append( + SparseWeightPatch(name=name, indices=all_idx[off : off + n], values=all_val[off : off + n]) + ) + off += n + if patches: + apply_patches(patches) + + t_apply = time.time() + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f) | download %.2fs decode+apply %.2fs", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + t_dl - t0, + t_apply - t_dl, + ) + + def shutdown(self) -> None: + pass + + @staticmethod + def trainer_send_weights(iterator, trainer_args) -> None: + raise NotImplementedError("Use AsyncRolloutWorker.upload_weights / apply_weights instead") + + @staticmethod + def upload( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + bucket_id: str, + filename: str, + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, + ) -> PatchMetadata | None: + """Encode params as a safetensors patch and push to the bucket. + + Returns the :class:`PatchMetadata` (also written to the safetensors header), or ``None`` if the iterator was + empty. + """ + tensors, meta = encode_patch(iterator, model_version=model_version, encoding=encoding) + if tensors is None: + return None + # Write to a temp file and upload the path: hf-xet's in-memory `upload_bytes` panics on + # multi-GB buffers (e.g. a full-model anchor); the file path uses the large-file code path. + with tempfile.TemporaryDirectory() as tmpdir: + local_path = f"{tmpdir}/patch.safetensors" + save_file(tensors, local_path, metadata=meta.to_metadata_dict()) + size_mb = os.path.getsize(local_path) / 1e6 + batch_bucket_files(bucket_id, add=[(local_path, filename)]) + logger.info( + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, enc=%s, sparsity=%.4f)", + bucket_id, + filename, + size_mb, + meta.num_changed_params, + meta.sparse, + meta.encoding.value, + meta.sparsity, + ) + return meta + + +def encode_patch( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + model_version: int = 0, + encoding: Encoding = Encoding.GAP_DELTA, +) -> tuple[dict[str, torch.Tensor] | None, PatchMetadata | None]: + """Build the safetensors tensor dict + metadata for a patch (no I/O). + + Each item is ``(name, tensor, mask)``: + + - ``mask is None``: full tensor stored as ``name`` (anchor). + - ``mask`` provided: GPU sparse-extract; store encoded indices as ``{name}.idx`` and values as ``{name}.val``. + ``encoding`` is ``"raw"`` (int32) or ``"gap_delta"`` (uint16 gap bytes, uint32 fallback per param). + """ + encoding = Encoding(encoding) + tensors: dict[str, torch.Tensor] = {} + delta_items: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + n_params = 0 + total_changed = 0 + total_elements = 0 + + for name, tensor, mask in iterator: + n_params += 1 + total_elements += tensor.numel() + if mask is None: # anchor: store the full tensor + tensors[name] = tensor.detach().to(torch.bfloat16).cpu().contiguous().clone() + total_changed += tensor.numel() + else: + delta_items.append((name, tensor, mask)) + + for name, idx, vals in extract_sparse_batched(delta_items): + total_changed += idx.numel() + tensors[f"{name}.val"] = vals.to(torch.bfloat16).cpu().contiguous() + tensors[f"{name}.idx"] = _encode_idx(idx, encoding) + + sparse = bool(delta_items) + if not tensors: + return None, None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=n_params, + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + encoding=encoding, + ) + return tensors, meta + + +def iter_sparse_patches(f) -> Iterator[tuple[str, torch.Tensor, torch.Tensor]]: + """Yield ``(name, int32 indices, values)`` from an open sparse safetensors handle. + + Flat, self-describing format: param names are recovered from the ``{name}.val`` tensor keys and the index encoding + is a single global header field (the gap-delta width is carried by the index tensor's own dtype). ``f`` is a + ``safetensors.safe_open`` handle. + """ + encoding = PatchMetadata.from_metadata_dict(f.metadata()).encoding + names = sorted({k[: -len(".val")] for k in f.keys() if k.endswith(".val")}) + for name in names: + idx = _decode_idx(f.get_tensor(f"{name}.idx"), encoding) + yield name, idx, f.get_tensor(f"{name}.val") + + +class DeltaWorkerExtension: + """vLLM worker-extension hook (pass via ``--worker-extension-cls``). + + Required: ``--worker-extension-cls`` makes the vLLM *worker* process import this module, which runs the + ``register_engine`` call below so the ``"delta"`` backend exists in the worker (the factory registry is + per-process)""" + + pass + + +if "delta" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..59143ac4024 --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,260 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization utilities. + +- ``BF16ChangeDetector``: hooks into the optimizer to detect which bf16 elements actually changed after each step. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both anchor (full) and delta (sparse) weight + files. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass +from enum import Enum + +import torch + +from .delta_codec import Encoding + + +logger = logging.getLogger(__name__) + + +class BF16ChangeDetector: + """Detects which bf16 weights actually changed across an optimizer step. + + Hooks into the optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` (PyTorch >= 2.1). Snapshots + bf16 values before the step, compares after. + + ``_validated_masks[name]`` is a boolean tensor with True for each element that changed. + """ + + def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "BF16ChangeDetector: matched %d/%d optimizer params", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + ) + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._pre_step_bf16.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None or name not in self._pre_step_bf16: + continue + self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name] + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +def _low_byte(p: torch.Tensor, device: torch.device | str | None = None) -> torch.Tensor: + """Low byte of each element's bf16 bit pattern as ``uint8`` (1 byte/elem). + + bf16 layout (16 bits): ``[sign | exp(8) | mantissa(7)]``. The low byte holds all 7 mantissa bits plus the exponent + LSB, so any sub-ULP / mantissa-level update flips it. Snapshotting just this byte costs 1 B/elem — half a full bf16 + clone. + + Stays on ``p``'s device by default (so the change mask is computed on-GPU and only the sparse payload crosses + PCIe). Pass ``device="cpu"`` to keep the snapshot in host memory when a full on-device snapshot would not fit (e.g. + DeepSpeed-Z2 holds full params/rank). + """ + bf16 = p.detach().to(torch.bfloat16) + if device is not None: + bf16 = bf16.to(device) + bf16 = bf16.contiguous() + # TODO: 0xFF should be configurable either at module level or via some number of precision bits. + return bf16.view(torch.int16).bitwise_and(0xFF).to(torch.uint8) + + +class LowByteChangeDetector: + """Detects changed bf16 weights from a 1-byte-per-element snapshot. + + Like [`BF16ChangeDetector`], hooks the optimizer (PyTorch >= 2.1) and diffs pre/post step — but snapshots only the + **low byte** of each weight's bf16 pattern (1 B/elem) instead of the full bf16 value (2 B/elem). A flipped low byte + implies the bf16 value changed, so the detected mask is a strict subset of the true change set: **no false + positives**, but rare false negatives (mantissa + exp-LSB unchanged while a high exponent/sign bit changed). Those + misses cause inference-side drift, bounded by periodic anchors — set ``validate_recall=True`` to measure the miss + rate. + + ``_validated_masks[name]`` is a boolean tensor, True for each element detected as changed. + + Args: + model ([`~torch.nn.Module`]): + Model whose parameters are tracked. + optimizer ([`~torch.optim.Optimizer`]): + Optimizer to hook. Must expose native ``register_step_*_hook`` (unwrap Accelerate first). + validate_recall (`bool`, *optional*, defaults to `False`): + Also keep a full bf16 snapshot to score low-byte detection against the true diff. Doubles the snapshot + cost; for diagnostics only. + snapshot_to_cpu (`bool`, *optional*, defaults to `False`): + Keep the low-byte snapshot in host memory instead of on the param's device. Masks are then produced on CPU. + Use when a full on-device snapshot would not fit (e.g. DeepSpeed-Z2 holds full params per rank). Default + keeps it on-GPU so the change mask and sparse extraction stay on the device. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + validate_recall: bool = False, + snapshot_to_cpu: bool = False, + ): + self.validate_recall = validate_recall + self._snap_device = "cpu" if snapshot_to_cpu else None + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_low: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} # only populated when validate_recall + self._accuracy: dict[str, float] = {} + + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} + self._param_id_to_name: dict[int, str] = {} + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "LowByteChangeDetector: matched %d/%d optimizer params (validate_recall=%s)", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + validate_recall, + ) + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._pre_step_low.clear() + self._pre_step_bf16.clear() + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None: + continue + self._pre_step_low[name] = _low_byte(p, self._snap_device) + if self.validate_recall: + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).to(self._snap_device or p.device).clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + total_tp, total_true, total_elements = 0, 0, 0 + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + name = self._param_id_to_name.get(id(p)) + if name is None or name not in self._pre_step_low: + continue + detected = _low_byte(p, self._snap_device) != self._pre_step_low[name] + self._validated_masks[name] = detected + if self.validate_recall: + # True bf16 diff, computed here while p still holds the post-step value. + post_bf16 = p.detach().to(torch.bfloat16).to(self._snap_device or p.device) + true_mask = post_bf16 != self._pre_step_bf16[name] + total_tp += (detected & true_mask).sum().item() + total_true += true_mask.sum().item() + total_elements += true_mask.numel() + if self.validate_recall: + # Low-byte changes ⊆ bf16 changes, so precision is 1.0 by construction; + # recall = fraction of truly-changed elements the low byte detected. + self._accuracy = { + "recall": total_tp / max(total_true, 1), + "true_changed": total_true, + "detected_changed": total_tp, + "total_elements": total_elements, + "sparsity": 1.0 - total_true / max(total_elements, 1), + } + + def get_prediction_accuracy(self) -> dict[str, float]: + return dict(self._accuracy) + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +@dataclass +class PatchMetadata: + format: str = "sparse_weight_patch" + version: str = "1" + sparse: bool = False + model_version: int = 0 + num_changed_params: int = 0 + total_changed_elements: int = 0 + total_elements: int = 0 + sparsity: float = 0.0 + encoding: Encoding = Encoding.GAP_DELTA # index encoding (delta files only) + + def to_metadata_dict(self) -> dict[str, str]: + return {k: (v.value if isinstance(v, Enum) else str(v)) for k, v in asdict(self).items()} + + @classmethod + def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + for k, v in d.items(): + if k not in field_types: + continue + ft = field_types[k] + if ft == "int": + kwargs[k] = int(v) + elif ft == "float": + kwargs[k] = float(v) + elif ft == "bool": + kwargs[k] = v == "True" + elif ft == "Encoding": + kwargs[k] = Encoding(v) + else: + kwargs[k] = v + return cls(**kwargs) diff --git a/trl/experimental/async_grpo/weight_transfer.py b/trl/experimental/async_grpo/weight_transfer.py index d0af83c2d2a..d32113df76b 100644 --- a/trl/experimental/async_grpo/weight_transfer.py +++ b/trl/experimental/async_grpo/weight_transfer.py @@ -17,6 +17,7 @@ import requests from accelerate.logging import get_logger +from huggingface_hub import create_bucket from trl.import_utils import is_vllm_available @@ -36,6 +37,10 @@ def __init__( weight_update_info: dict, server_timeout: float = 240.0, init_weight_transfer_timeout: int = 1800, + delta_sync_enabled: bool = False, + delta_sync_repo_id: str | None = None, + delta_sync_anchor_interval: int = 10, + delta_sync_encoding: str = "gap_delta", ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( @@ -46,6 +51,13 @@ def __init__( self.init_weight_transfer_timeout = init_weight_transfer_timeout self._weight_update_info = weight_update_info self.model_update_group = None + # Delta sync (Transport B): sparse patches over an HF Storage Bucket instead of NCCL. + self.delta_sync_enabled = delta_sync_enabled + self._delta_sync_repo_id = delta_sync_repo_id + self._delta_sync_anchor_interval = delta_sync_anchor_interval + self._delta_sync_encoding = delta_sync_encoding + self._delta_model_version = 0 + self._delta_pending: dict | None = None def _wait_for_server_ready_sync(self, timeout_s: float | None = None, poll_interval_s: float = 2.0) -> None: timeout_s = timeout_s if timeout_s is not None else self.server_timeout @@ -72,6 +84,15 @@ def _wait_for_server_ready_sync(self, timeout_s: float | None = None, poll_inter def init_weight_transfer(self) -> None: self._wait_for_server_ready_sync() + if self.delta_sync_enabled: + create_bucket(self._delta_sync_repo_id, exist_ok=True) + requests.post( + f"{self.vllm_server_url}/init_weight_transfer_engine", + json={"init_info": {}}, + timeout=self.init_weight_transfer_timeout, + ) + logger.info("Initialised delta weight transfer (bucket %s)", self._delta_sync_repo_id) + return response = requests.get(f"{self.vllm_server_url}/get_world_size") inference_world_size = response.json()["world_size"] world_size = inference_world_size + 1 @@ -123,6 +144,74 @@ def send_weights(self, iterator) -> None: f"(total send_weights: {time.time() - t0:.1f}s)" ) + def upload_weights(self, iterator) -> None: + """Delta phase 1 (inference still running): encode the changed params as a sparse patch, + upload it to the bucket, and record where [`apply_weights`] should fetch it. + + Every Nth sync is a full anchor; the rest are gap-delta patches. An empty iterator (nothing changed) is a no-op + and leaves ``_delta_pending`` cleared, so the apply is skipped. The phase is explicit — never inferred from + emptiness — so a zero-change step can't trigger an apply. + """ + from .delta_engine import DeltaWeightTransferEngine + + self._delta_model_version += 1 + is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0 + if is_anchor: + iterator = ((name, tensor, None) for name, tensor, _mask in iterator) # strip masks -> full tensors + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" + meta = DeltaWeightTransferEngine.upload( + iterator=iterator, + bucket_id=self._delta_sync_repo_id, + filename=filename, + model_version=self._delta_model_version, + encoding=self._delta_sync_encoding, + ) + self._delta_pending = ( + None + if meta is None + else { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "update_kind": "dense" if is_anchor else "sparse_flat", # "dense" <=> anchor + } + ) + + def apply_weights_delta(self) -> None: + """Signal vLLM to fetch and apply the uploaded patch. + + No-op when nothing was uploaded this step; ``_delta_pending`` is cleared up front so a failed apply leaves no + stale state. + """ + if self._delta_pending is None: + return + info, self._delta_pending = self._delta_pending, None + # Anchors are HF-checkpoint-format full tensors; deltas are sparse kernel-format. + self._post_vllm("/start_weight_update", {"is_checkpoint_format": info["update_kind"] == "dense"}) + # vLLM fetches the patch from the bucket inside this call; a full anchor can take minutes, so + # the timeout must cover the download — otherwise a read-timeout would retry into a re-download. + self._post_vllm("/update_weights", {"update_info": info}, retries=5, timeout=1800) + self._post_vllm("/finish_weight_update", {}) + + def _post_vllm(self, path: str, json_body: dict, retries: int = 1, timeout: int = 300) -> None: + """POST to a vLLM server endpoint with bounded retry on 429 / connection errors.""" + url = f"{self.vllm_server_url}{path}" + for attempt in range(retries): + try: + resp = requests.post(url, json=json_body, timeout=timeout) + if resp.status_code < 429: + resp.raise_for_status() + return + status = resp.status_code + except requests.RequestException as e: + logger.warning(f"[weight_sync] POST {path} failed: {e}") + status = "connection error" + if attempt < retries - 1: + wait = min(2**attempt, 30) + logger.warning(f"[weight_sync] POST {path} -> {status}, retry in {wait}s ({attempt + 1}/{retries})") + time.sleep(wait) + raise RuntimeError(f"[weight_sync] POST {path} failed after {retries} attempt(s)") + def pause(self) -> None: t0 = time.time() requests.post(f"{self.vllm_server_url}/pause", params={"mode": "keep"}) From 36d54dbeb35544a2f41f26ba87702002a0c925ab Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 4 Jun 2026 09:51:34 +0000 Subject: [PATCH 2/5] Add AsyncGRPO delta weight sync example --- examples/scripts/async_grpo_delta.py | 95 ++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 examples/scripts/async_grpo_delta.py diff --git a/examples/scripts/async_grpo_delta.py b/examples/scripts/async_grpo_delta.py new file mode 100644 index 00000000000..cddf22eb5d3 --- /dev/null +++ b/examples/scripts/async_grpo_delta.py @@ -0,0 +1,95 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AsyncGRPO with delta weight sync (Transport B: HF Storage Bucket + in-place sparse apply). + +Only changed bf16 weights are encoded as a sparse safetensors patch, uploaded to a bucket, and +applied in place on vLLM via PR #40096 — no full-model broadcast, no vLLM-side snapshot. + +Start the vLLM server with the `delta` backend + worker extension (registers the engine) and the +`transformers` model impl (so vLLM's runtime param names match the trainer's HF names — every +param is then addressable by the in-place sparse apply, no fuse/unfuse remap needed): + +# VLLM_USE_V2_MODEL_RUNNER=0 is required: the in-place sparse apply (apply_sparse_weight_patches, +# vLLM #40096) exists only on the V1 model runner. Without it the server picks V2 and every sparse +# delta update fails (the dense anchors still work, so it silently degrades to anchor-only sync). +CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \ + --model-impl transformers \ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \ + --weight-transfer-config '{"backend":"delta"}' \ + --max-model-len 2560 + +CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo_delta.py +""" + +import logging +import os + +from datasets import load_dataset + +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer +from trl.rewards import accuracy_reward + + +logging.basicConfig( + level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logging.getLogger("trl").setLevel(logging.INFO) + + +def format_sample(sample): + return { + "prompt": [{"role": "user", "content": sample["question"]}], + "solution": sample["answer"].split("####")[-1].strip(), + } + + +def main() -> None: + dataset = load_dataset("openai/gsm8k", "main", split="train") + dataset = dataset.map(format_sample, remove_columns=dataset.column_names) + + config = AsyncGRPOConfig( + output_dir="./results/async_grpo_delta", + per_device_train_batch_size=1, + num_generations=8, + max_completion_length=512, + max_steps=60, + learning_rate=1e-5, + logging_steps=1, + bf16=True, + report_to="none", + project="async_grpo_delta", + log_completions=True, + # Qwen3 thinking traces blow past the completion cap on GSM8K (truncated -> no answer -> + # zero reward); disable thinking so completions are short and accuracy_reward gets signal. + chat_template_kwargs={"enable_thinking": False}, + # --- delta weight sync (Transport B) --- + delta_sync_enabled=True, + delta_sync_repo_id="aminediroHF/async-grpo-delta-demo", + delta_sync_anchor_interval=20, # full anchor every N syncs; sparse deltas in between + delta_sync_encoding="gap_delta", # raw | gap_delta | nvcomp_cascaded + ) + trainer = AsyncGRPOTrainer( + model="Qwen/Qwen3-1.7B", + args=config, + train_dataset=dataset, + reward_funcs=accuracy_reward, + ) + trainer.train() + + +if __name__ == "__main__": + main() From ceae48b9323874c9688cba1d9fc6c26659d66ff0 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 4 Jun 2026 20:16:39 +0000 Subject: [PATCH 3/5] Use placeholder HF bucket in delta sync example --- examples/scripts/async_grpo_delta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/async_grpo_delta.py b/examples/scripts/async_grpo_delta.py index cddf22eb5d3..d8bfc46e08b 100644 --- a/examples/scripts/async_grpo_delta.py +++ b/examples/scripts/async_grpo_delta.py @@ -78,7 +78,7 @@ def main() -> None: chat_template_kwargs={"enable_thinking": False}, # --- delta weight sync (Transport B) --- delta_sync_enabled=True, - delta_sync_repo_id="aminediroHF/async-grpo-delta-demo", + delta_sync_repo_id="/async-grpo-delta-demo", # set to a bucket you own delta_sync_anchor_interval=20, # full anchor every N syncs; sparse deltas in between delta_sync_encoding="gap_delta", # raw | gap_delta | nvcomp_cascaded ) From cb4154e30e59e7a0f994f5073b8a4e9ccfd22687 Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 5 Jun 2026 13:18:56 +0000 Subject: [PATCH 4/5] Halve change-detector peak memory: free each pre-step snapshot as it's diffed --- trl/experimental/async_grpo/weight_diff.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py index 59143ac4024..a65ad0ebfe8 100644 --- a/trl/experimental/async_grpo/weight_diff.py +++ b/trl/experimental/async_grpo/weight_diff.py @@ -86,7 +86,8 @@ def _post_step_hook(self, optimizer, args, kwargs) -> None: name = self._param_id_to_name.get(id(p)) if name is None or name not in self._pre_step_bf16: continue - self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name] + # pop so each pre-step snapshot is freed right after it's diffed (peak ~1× snapshot, not 2×). + self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16.pop(name) def close(self): self._pre_hook_handle.remove() @@ -196,12 +197,14 @@ def _post_step_hook(self, optimizer, args, kwargs) -> None: name = self._param_id_to_name.get(id(p)) if name is None or name not in self._pre_step_low: continue - detected = _low_byte(p, self._snap_device) != self._pre_step_low[name] + # pop (not index) so each param's pre-step snapshot is freed as soon as it's diffed: + # the shrinking snapshot set + growing mask set stay ~1 B/elem total instead of 2. + detected = _low_byte(p, self._snap_device) != self._pre_step_low.pop(name) self._validated_masks[name] = detected if self.validate_recall: # True bf16 diff, computed here while p still holds the post-step value. post_bf16 = p.detach().to(torch.bfloat16).to(self._snap_device or p.device) - true_mask = post_bf16 != self._pre_step_bf16[name] + true_mask = post_bf16 != self._pre_step_bf16.pop(name) total_tp += (detected & true_mask).sum().item() total_true += true_mask.sum().item() total_elements += true_mask.numel() From cd8350624e40619f9dd746f608eb52b5d0c8e47d Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 5 Jun 2026 19:44:11 +0000 Subject: [PATCH 5/5] weight diff bucketed --- trl/experimental/async_grpo/weight_diff.py | 144 +++++++++++++-------- 1 file changed, 88 insertions(+), 56 deletions(-) diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py index a65ad0ebfe8..bba22501477 100644 --- a/trl/experimental/async_grpo/weight_diff.py +++ b/trl/experimental/async_grpo/weight_diff.py @@ -114,14 +114,16 @@ def _low_byte(p: torch.Tensor, device: torch.device | str | None = None) -> torc class LowByteChangeDetector: - """Detects changed bf16 weights from a 1-byte-per-element snapshot. + """Detects changed bf16 weights from a 1-byte-per-element snapshot kept in host (CPU) memory. Like [`BF16ChangeDetector`], hooks the optimizer (PyTorch >= 2.1) and diffs pre/post step — but snapshots only the **low byte** of each weight's bf16 pattern (1 B/elem) instead of the full bf16 value (2 B/elem). A flipped low byte implies the bf16 value changed, so the detected mask is a strict subset of the true change set: **no false - positives**, but rare false negatives (mantissa + exp-LSB unchanged while a high exponent/sign bit changed). Those - misses cause inference-side drift, bounded by periodic anchors — set ``validate_recall=True`` to measure the miss - rate. + positives**, but rare false negatives. + + The snapshot is kept on CPU (0 GPU memory footprint) using pre-allocated pinned memory to maximize transfer + bandwidth and avoid runtime allocation overhead. GPU-side diffing is performed in bounded buckets to prevent VRAM + explosion while maintaining maximum PCIe saturation. ``_validated_masks[name]`` is a boolean tensor, True for each element detected as changed. @@ -133,10 +135,10 @@ class LowByteChangeDetector: validate_recall (`bool`, *optional*, defaults to `False`): Also keep a full bf16 snapshot to score low-byte detection against the true diff. Doubles the snapshot cost; for diagnostics only. - snapshot_to_cpu (`bool`, *optional*, defaults to `False`): - Keep the low-byte snapshot in host memory instead of on the param's device. Masks are then produced on CPU. - Use when a full on-device snapshot would not fit (e.g. DeepSpeed-Z2 holds full params per rank). Default - keeps it on-GPU so the change mask and sparse extraction stay on the device. + snapshot_to_cpu (`bool`, *optional*, defaults to `True`): + Kept for backwards compatibility. Snapshots are always kept in host memory. + bucket_mb (`int`, *optional*, defaults to `128`): + Cap peak GPU memory staging to ~3x this size (e.g. ~384 MB for 128 MB) during transfers and diffing. """ def __init__( @@ -144,73 +146,103 @@ def __init__( model: torch.nn.Module, optimizer: torch.optim.Optimizer, validate_recall: bool = False, - snapshot_to_cpu: bool = False, + snapshot_to_cpu: bool = True, + bucket_mb: int = 128, ): self.validate_recall = validate_recall - self._snap_device = "cpu" if snapshot_to_cpu else None + self.bucket_bytes = bucket_mb * 1024 * 1024 self._validated_masks: dict[str, torch.Tensor] = {} - self._pre_step_low: dict[str, torch.Tensor] = {} - self._pre_step_bf16: dict[str, torch.Tensor] = {} # only populated when validate_recall - self._accuracy: dict[str, float] = {} - # Match model param names to optimizer param objects via data_ptr() - # (id() doesn't work because Accelerate wraps params as different objects) model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} - self._param_id_to_name: dict[int, str] = {} + self._params: list[tuple[str, torch.Tensor]] = [] for group in optimizer.param_groups: for p in group["params"]: name = model_params.get(p.data_ptr()) - if name is not None: - self._param_id_to_name[id(p)] = name + if name is not None and p.requires_grad: + self._params.append((name, p)) + + self._buckets: list[ + list[tuple[str, torch.Tensor, int, int]] + ] = [] # List of buckets: [(name, p, offset_in_bucket, length)] + if self._params: + self._device = self._params[0][1].device + + current_bucket = [] + bucket_size = 0 + for name, p in self._params: + current_bucket.append((name, p, bucket_size, p.numel())) + bucket_size += p.numel() + if bucket_size >= self.bucket_bytes: + self._buckets.append(current_bucket) + current_bucket = [] + bucket_size = 0 + if current_bucket: + self._buckets.append(current_bucket) + + # Pre-allocate pinned CPU memory matching the bucket structures + self._pinned_pre_low: list[torch.Tensor] = [] + self._pinned_post_mask: list[torch.Tensor] = [] + for bucket in self._buckets: + total_numel = sum(length for _, _, _, length in bucket) + self._pinned_pre_low.append(torch.empty(total_numel, dtype=torch.uint8, device="cpu", pin_memory=True)) + self._pinned_post_mask.append( + torch.empty(total_numel, dtype=torch.bool, device="cpu", pin_memory=True) + ) - logger.info( - "LowByteChangeDetector: matched %d/%d optimizer params (validate_recall=%s)", - len(self._param_id_to_name), - sum(1 for _ in model.named_parameters()), - validate_recall, - ) + self._pre_step_bf16: dict[str, torch.Tensor] = {} + self._accuracy: dict[str, float] = {} self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) def _pre_step_hook(self, optimizer, args, kwargs) -> None: - self._pre_step_low.clear() self._pre_step_bf16.clear() - for group in optimizer.param_groups: - for p in group["params"]: - if p.grad is None: - continue - name = self._param_id_to_name.get(id(p)) - if name is None: - continue - self._pre_step_low[name] = _low_byte(p, self._snap_device) - if self.validate_recall: - self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).to(self._snap_device or p.device).clone() + if not self._buckets: + return + + # Process and copy bucket-by-bucket asynchronously + for b_idx, bucket in enumerate(self._buckets): + gpu_buf = torch.cat([_low_byte(p).view(-1) for _, p, _, _ in bucket]) + self._pinned_pre_low[b_idx].copy_(gpu_buf, non_blocking=True) + + if self.validate_recall: + for name, p in self._params: + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() def _post_step_hook(self, optimizer, args, kwargs) -> None: self._validated_masks.clear() - total_tp, total_true, total_elements = 0, 0, 0 - for group in optimizer.param_groups: - for p in group["params"]: - if p.grad is None: - continue - name = self._param_id_to_name.get(id(p)) - if name is None or name not in self._pre_step_low: - continue - # pop (not index) so each param's pre-step snapshot is freed as soon as it's diffed: - # the shrinking snapshot set + growing mask set stay ~1 B/elem total instead of 2. - detected = _low_byte(p, self._snap_device) != self._pre_step_low.pop(name) - self._validated_masks[name] = detected - if self.validate_recall: - # True bf16 diff, computed here while p still holds the post-step value. - post_bf16 = p.detach().to(torch.bfloat16).to(self._snap_device or p.device) - true_mask = post_bf16 != self._pre_step_bf16.pop(name) - total_tp += (detected & true_mask).sum().item() - total_true += true_mask.sum().item() - total_elements += true_mask.numel() + if not self._buckets: + return + + for b_idx, bucket in enumerate(self._buckets): + cur_buf = torch.cat([_low_byte(p).view(-1) for _, p, _, _ in bucket]) + + # Fetch only this bucket's pre-step snapshot to GPU and diff + prev_buf = self._pinned_pre_low[b_idx].to(self._device, non_blocking=True) + diff = cur_buf != prev_buf + + self._pinned_post_mask[b_idx].copy_(diff, non_blocking=True) + + torch.cuda.synchronize() + + # Unpack masks back to model format + for b_idx, bucket in enumerate(self._buckets): + mask_buf = self._pinned_post_mask[b_idx] + for name, p, offset, length in bucket: + self._validated_masks[name] = mask_buf[offset : offset + length].view(p.shape).clone() + if self.validate_recall: - # Low-byte changes ⊆ bf16 changes, so precision is 1.0 by construction; - # recall = fraction of truly-changed elements the low byte detected. + total_tp, total_true, total_elements = 0, 0, 0 + for name, p in self._params: + if name not in self._validated_masks or name not in self._pre_step_bf16: + continue + detected = self._validated_masks[name] + post_bf16 = p.detach().to(torch.bfloat16).cpu() + true_mask = post_bf16 != self._pre_step_bf16.pop(name) + total_tp += (detected.cpu() & true_mask).sum().item() + total_true += true_mask.sum().item() + total_elements += true_mask.numel() + self._accuracy = { "recall": total_tp / max(total_true, 1), "true_changed": total_true,