Delta weight sync for AsyncGRPO (sparse patches over HF Bucket)#5937
AmineDiro wants to merge 6 commits into
| full = param.full_tensor() if isinstance(param, DTensor) else param.detach() | ||
| if full.device != device: | ||
| full = full.to(device) | ||
| yield name, full, mask |
Periodic anchors omit unchanged params
High Severity
After training starts, _streaming_iter_delta only yields parameters the low-byte detector marked as changed. Periodic anchor uploads still use that iterator and only set mask=None on those tensors, so vLLM receives a partial checkpoint instead of a full model refresh. If no weights changed, the anchor upload can be skipped entirely while inference keeps stale weights.
Reviewed by Cursor Bugbot for commit 36d54db. Configure here.
| kwargs[k] = Encoding(v) | ||
| else: | ||
| kwargs[k] = v | ||
| return cls(**kwargs) |
Metadata dict type coercion broken
High Severity
PatchMetadata.from_metadata_dict compares each field’s type to the strings "int", "float", "bool", and "Encoding", but dataclass f.type values are actual types (int, float, etc.). Safetensors header values stay as strings, so vLLM logging and any numeric use of model_version or sparsity can raise TypeError after a successful patch apply.
Reviewed by Cursor Bugbot for commit 36d54db. Configure here.
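A minimal sketch of a coercion that does not depend on how the annotations are stored. The field set and the Encoding members below are hypothetical, chosen only for illustration; the PR's actual PatchMetadata may differ. The lookup accepts both real types and their string names, since dataclass f.type can be either depending on whether annotations are postponed.

```python
from dataclasses import dataclass, fields
from enum import Enum


class Encoding(str, Enum):  # hypothetical members, for illustration only
    RAW = "raw"
    GAP_DELTA = "gap_delta"


# Map annotation names to types so string annotations are handled as well.
_TYPES_BY_NAME = {"int": int, "float": float, "bool": bool, "Encoding": Encoding}


@dataclass
class PatchMetadata:  # hypothetical field set
    model_version: int
    sparsity: float
    is_anchor: bool
    encoding: Encoding

    @classmethod
    def from_metadata_dict(cls, meta):
        kwargs = {}
        for f in fields(cls):
            if f.name not in meta:
                continue
            raw = meta[f.name]  # safetensors header values are strings
            typ = f.type if isinstance(f.type, type) else _TYPES_BY_NAME.get(f.type)
            if typ is bool:
                # bool("False") is True, so parse the text explicitly.
                kwargs[f.name] = raw.strip().lower() in ("1", "true")
            elif typ in (int, float, Encoding):
                kwargs[f.name] = typ(raw)
            else:
                kwargs[f.name] = raw
        return cls(**kwargs)


meta = PatchMetadata.from_metadata_dict(
    {"model_version": "7", "sparsity": "0.93", "is_anchor": "False", "encoding": "gap_delta"}
)
```

With this shape, numeric use of model_version or sparsity after a patch apply works on real numbers rather than header strings.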
| # vLLM fetches the patch from the bucket inside this call; a full anchor can take minutes, so | ||
| # the timeout must cover the download — otherwise a read-timeout would retry into a re-download. | ||
| self._post_vllm("/update_weights", {"update_info": info}, retries=5, timeout=1800) | ||
| self._post_vllm("/finish_weight_update", {}) |
Weight update not finished on error
Medium Severity
apply_weights_delta calls /start_weight_update and /update_weights sequentially without try/finally. If /update_weights fails after retries, /finish_weight_update is never sent, while the trainer catches the exception and resumes vLLM anyway, which can leave the server in an inconsistent in-flight weight-update state.
Reviewed by Cursor Bugbot for commit 36d54db. Configure here.
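The fix suggested here can be sketched with a stand-in client. FakeVLLMClient, its fail_on parameter, and the empty payload for /start_weight_update are assumptions made for the example; only the three route names and the _post_vllm call shape come from the diff. Whether vLLM accepts /finish_weight_update after a failed apply is also an assumption that would need checking against the server.

```python
class FakeVLLMClient:
    """Stand-in for the trainer's HTTP client; records the routes it is asked to call."""

    def __init__(self, fail_on=None):
        self.calls = []
        self.fail_on = fail_on  # route that should simulate a failure

    def _post_vllm(self, route, payload, retries=1, timeout=60):
        self.calls.append(route)
        if route == self.fail_on:
            raise RuntimeError(f"{route} failed after {retries} retries")

    def apply_weights_delta(self, info):
        self._post_vllm("/start_weight_update", {})
        try:
            self._post_vllm("/update_weights", {"update_info": info}, retries=5, timeout=1800)
        finally:
            # Always close the update window, even when the patch apply raised.
            self._post_vllm("/finish_weight_update", {})


client = FakeVLLMClient(fail_on="/update_weights")
try:
    client.apply_weights_delta({"patch": "example"})
    raised = False
except RuntimeError:
    raised = True  # the error still propagates to the caller
```

The exception still reaches the trainer, so its existing "log and resume vLLM" handling keeps working, but the server is no longer left mid-update.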
| self._delta_model_version += 1 | ||
| is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0 | ||
| if is_anchor: | ||
| iterator = ((name, tensor, None) for name, tensor, _mask in iterator) # strip masks -> full tensors |
Periodic anchors omit unchanged weights
High Severity
When is_anchor is true, upload_weights only strips masks on whatever _streaming_iter_delta yields. That iterator skips unchanged parameters, so periodic anchors upload a subset of tensors while vLLM treats them as a dense checkpoint. Unlisted weights stay stale on the inference side.
Reviewed by Cursor Bugbot for commit e3166c0. Configure here.
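A sketch of the behavior the comment asks for, using plain lists as stand-ins for tensors and masks. The function name iter_patch and its signature are hypothetical; the point is only that the anchor decision has to be made before the "skip unchanged" filter, not after it.

```python
from typing import Dict, Iterator, List, Optional, Tuple

Tensor = List[float]  # stand-in for a torch.Tensor
Mask = List[bool]     # stand-in for a boolean change mask


def iter_patch(
    params: Dict[str, Tensor],
    masks: Dict[str, Mask],
    is_anchor: bool,
) -> Iterator[Tuple[str, Tensor, Optional[Mask]]]:
    for name, tensor in params.items():
        if is_anchor:
            # Anchor: every parameter is sent dense, changed or not.
            yield name, tensor, None
            continue
        mask = masks.get(name)
        if mask is None or not any(mask):
            continue  # delta: an unchanged parameter is left out of the patch
        yield name, tensor, mask


params = {"w": [0.1, 0.2], "b": [0.0]}
masks = {"w": [True, False]}  # only "w" changed in this step
anchor_names = [name for name, _, _ in iter_patch(params, masks, is_anchor=True)]
delta_names = [name for name, _, _ in iter_patch(params, masks, is_anchor=False)]
```

In the real trainer the gather for sharded parameters would still have to run on every rank; this sketch only covers which tensors end up in the upload.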
| def init_weight_transfer(self) -> None: | ||
| self._wait_for_server_ready_sync() | ||
| if self.delta_sync_enabled: | ||
| create_bucket(self._delta_sync_repo_id, exist_ok=True) |
Delta sync without bucket id
Medium Severity
delta_sync_enabled can be set while delta_sync_repo_id stays None; init_weight_transfer then calls create_bucket with a null id. Training fails at startup instead of a clear config error.
Reviewed by Cursor Bugbot for commit e3166c0. Configure here.
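One way to surface this as a config error, sketched with a hypothetical DeltaSyncConfig dataclass. Only the two field names come from the review comment; where the check would live in the actual config class is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DeltaSyncConfig:  # hypothetical container for the two flags
    delta_sync_enabled: bool = False
    delta_sync_repo_id: Optional[str] = None

    def __post_init__(self):
        # Fail at config time, before any server wait or bucket creation.
        if self.delta_sync_enabled and not self.delta_sync_repo_id:
            raise ValueError(
                "delta_sync_enabled=True requires delta_sync_repo_id to be set "
                "to the bucket used for weight patches"
            )


try:
    DeltaSyncConfig(delta_sync_enabled=True)
    error = None
except ValueError as exc:
    error = str(exc)
```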
| name = name.removeprefix("module.") # DDP/FSDP1 wrapping | ||
| mask = masks.get(name) if masks else None | ||
| if masks and (mask is None or not mask.any()): | ||
| continue # unchanged param -> not in this delta |
FSDP collectives skipped on delta sync
High Severity
_streaming_iter_delta skips full_tensor() for parameters deemed unchanged, but non-main ranks still walk the same loop for FSDP2 collectives. Per-rank low-byte masks can differ under sharding, so ranks may skip different parameters and deadlock or corrupt the gather.
Reviewed by Cursor Bugbot for commit e3166c0. Configure here.
| full = param.full_tensor() if isinstance(param, DTensor) else param.detach() | ||
| if full.device != device: | ||
| full = full.to(device) | ||
| yield name, full, mask |
Shard masks vs full tensors
High Severity
For DTensor parameters, full_tensor() is global but change masks come from the local optimizer shard. Sparse encode pairs that mask with the gathered full tensor, so indices and values can be wrong or indexing can fail.
Reviewed by Cursor Bugbot for commit e3166c0. Configure here.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I've run a micro-benchmark (no GRPO loop): the policy is held on GPU 0, vLLM serves on GPU 1 (on the same node), and we time N weight syncs to compare FULL sync NCCL vs using buckets for small models.
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
There are 9 total unresolved issues (including 7 from previous reviews).
Reviewed by Cursor Bugbot for commit cd83506. Configure here.
| prev = 0 | ||
| for i, (name, _, _) in enumerate(items): | ||
| g = global_idx[prev : bounds[i]] - offsets[i] # local positions within this param | ||
| out.append((name, g.to(torch.int32), flats[i].index_select(0, g))) |
CPU masks break GPU sparse extract
High Severity
LowByteChangeDetector keeps change masks on CPU, and _streaming_iter_delta passes them unchanged into extract_sparse_batched, which runs nonzero on CPU masks then index_select on GPU weight tensors. PyTorch expects indices on the same device as the source tensor, so sparse delta encoding can fail once any weight changes.
Reviewed by Cursor Bugbot for commit cd83506. Configure here.
| self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() | ||
|
|
||
| def _post_step_hook(self, optimizer, args, kwargs) -> None: | ||
| self._validated_masks.clear() |
Multi-step sync drops earlier deltas
Medium Severity
The change detector clears and rebuilds _validated_masks after each optimizer step, retaining only the latest step’s changes. With weight_sync_steps greater than one, a sync sends just that last step’s sparse patch, so vLLM never receives updates from intervening steps.
Reviewed by Cursor Bugbot for commit cd83506. Configure here.
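A sketch of accumulating masks across optimizer steps instead of clearing them on every step. MaskAccumulator, record_step, and drain are hypothetical names, and masks are plain lists of booleans here; with torch boolean masks the same union could be taken with an in-place `|=`.

```python
class MaskAccumulator:
    """Keeps the union of per-step change masks until the next weight sync."""

    def __init__(self):
        self._pending = {}

    def record_step(self, step_masks):
        # Called after each optimizer step: OR the new masks into the pending ones.
        for name, mask in step_masks.items():
            prev = self._pending.get(name)
            if prev is None:
                self._pending[name] = list(mask)
            else:
                self._pending[name] = [a or b for a, b in zip(prev, mask)]

    def drain(self):
        # Called at sync time: hand over everything changed since the last sync.
        out, self._pending = self._pending, {}
        return out


acc = MaskAccumulator()
acc.record_step({"w": [True, False, False]})               # step 1
acc.record_step({"w": [False, False, True], "b": [True]})  # step 2
synced = acc.drain()
after_sync = acc.drain()
```

Clearing happens only in drain, so a sync with weight_sync_steps greater than one covers every intervening step.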
After discussing with @qgallouedec and running some benchmarks: even at 1 byte/param for the diff, we can't keep it in the GPU's VRAM, especially for big models, so we have to offload to CPU. To reduce D2H sync time, in cd83506 I send only the low bits, use buckets, and use pinned memory for async transfer.


What does this PR do?
Adds delta (sparse) weight synchronization to the experimental AsyncGRPO. Instead of broadcasting the full policy to vLLM on every weight sync, the trainer detects which bf16 weights actually changed after the optimizer step, encodes only those as a sparse safetensors patch, and pushes it to an HF Storage Bucket. On the inference engine side, vLLM applies the patch in place, with no full-model broadcast, using the new SparseWeightPatch from vLLM (still waiting for a release). At steady state, a delta is ~1–10% of the model, and sparsity rises as the LR decays.
How it works
Trainer side
- LowByteChangeDetector hooks into the optimizer and snapshots only the low byte of each weight's bf16 pattern (1 B/elem, half a full clone). A flipped low byte implies a changed bf16 value, so precision is 1.0 by construction; recall (measured) is 1.0 in normal training. Misses cause inference drift, bounded by periodic full anchors.
- delta_codec.py does the sparse extraction on the GPU: extract_sparse_batched runs a single nonzero (one device sync) split across all changed params, instead of one nonzero per param (~16× less than the old dense-D2H path).
- Index encodings: raw (int32), gap_delta (uint16 gaps, 2×), nvcomp_cascaded (GPU Cascaded delta+bitpack, ~3×; optional, adds a dependency).
- DeltaWeightTransferEngine.upload writes one safetensors patch (anchor = full tensors; delta = {name}.idx + {name}.val) and pushes it to the bucket. The format is self-describing: names come from the .val keys, the encoding from a global header field, and the gap-delta width from the index dtype.

Lifecycle: the trainer drives vLLM's start_weight_update / update_weights / finish_weight_update HTTP routes; the change detector is created in compute_loss before the first optimizer.step. The NCCL path is untouched.
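The low-byte check and the gap encoding described above can be illustrated in pure Python. This is a sketch under stated assumptions, not the PR's GPU code: bf16_bits truncates the float32 pattern (a real bf16 cast rounds), all function names are hypothetical, and gaps are kept as plain integers rather than packed into uint16.

```python
import struct
from itertools import accumulate


def bf16_bits(x: float) -> int:
    """Upper 16 bits of the float32 pattern (truncation; a real bf16 cast rounds)."""
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16


def low_bytes(values):
    """One byte per element: the low byte of each bf16 pattern."""
    return bytes(bf16_bits(v) & 0xFF for v in values)


def changed_mask(snapshot: bytes, values):
    """True where the low byte differs from the pre-step snapshot."""
    return [old != new for old, new in zip(snapshot, low_bytes(values))]


def gap_encode(indices):
    """Sorted positions -> gaps between consecutive positions (first gap = first index)."""
    return [cur - prev for prev, cur in zip([0] + indices[:-1], indices)]


def gap_decode(gaps):
    """Inverse of gap_encode: a running sum restores the positions."""
    return list(accumulate(gaps))


before = [1.0, 1.0, 1.0, 1.0]
snapshot = low_bytes(before)          # taken before optimizer.step
after = [1.0, 1.5, -1.0, 2.0]         # index 2 only flips the sign bit (high byte)
mask = changed_mask(snapshot, after)  # flags 1 and 3, misses 2
indices = [i for i, m in enumerate(mask) if m]
gaps = gap_encode(indices)
```

Index 2 shows the kind of miss that periodic anchors are meant to bound: the value changed, but only in the high byte. Every flagged position is a real change, which is the "precision is 1.0 by construction" property.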
Requirements/constraints
- vLLM: main after v0.22.0, not in a release yet; install from nightly. New delta_weight_sync extra added to pyproject.toml.
- --model-impl transformers and VLLM_USE_V2_MODEL_RUNNER=0 (apply_sparse_weight_patches exists only on the V1 runner 😢). Example:

    CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 \
      vllm serve Qwen/Qwen3-1.7B \
      --model-impl transformers \
      --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \
      --weight-transfer-config '{"backend":"delta"}' \
      --max-model-len 2560

Tests
raw/gap_delta/nvcomp_cascaded, and extract_sparse_batched
decode+apply ~1 s.
References
AI writing disclosure
Note
High Risk
Changes the training–inference weight path and depends on unreleased vLLM sparse apply; failed applies are logged and inference may run on stale weights until the next anchor.
Overview
Adds optional delta weight sync to experimental AsyncGRPO so vLLM can be updated with sparse bf16 patches over an HF Storage Bucket instead of a full NCCL broadcast on every sync.
Trainer path: New delta_sync_* config flags wire into WeightTransferClient. LowByteChangeDetector (optimizer hooks) tracks which weights changed; _sync_weight uploads patches while inference still runs, then pauses vLLM and calls apply_weights_delta. Periodic anchors (full tensors) limit drift from low-byte detection misses. NCCL full sync stays the default when delta sync is off.
New modules: delta_codec (GPU sparse extract + raw/gap_delta/nvcomp_cascaded index encoding), weight_diff (PatchMetadata, change detectors), delta_engine (safetensors encode/upload + vLLM delta weight-transfer engine registration via DeltaWorkerExtension).
Also adds examples/scripts/async_grpo_delta.py and optional delta_weight_sync extra (huggingface-hub, optional nvcomp).
Reviewed by Cursor Bugbot for commit cd83506. Bugbot is set up for automated code reviews on this repo.