Skip to content

Delta weight sync for AsyncGRPO (sparse patches over HF Bucket)#5937

Open
AmineDiro wants to merge 6 commits into
mainfrom
delta-weight-sync-v3
Open

Delta weight sync for AsyncGRPO (sparse patches over HF Bucket)#5937
AmineDiro wants to merge 6 commits into
mainfrom
delta-weight-sync-v3

Conversation

@AmineDiro
Copy link
Copy Markdown
Member

@AmineDiro AmineDiro commented Jun 4, 2026

What does this PR do?

Adds delta (sparse) weight synchronization to the experimental AsyncGRPO. Instead of broadcasting the full policy to vLLM on every weight sync, the trainer detects which bf16 weights actually changed after the optimizer step, encodes only those as a sparse safetensors patch, pushes it to an HF Storage Bucket. On the inference engine side, vLLM applies it in place no full-model broadcast, we use the NEW SparseWeightPatch from vllm (still waiting for a release)

At steady state, a delta is ~1–10% of the model, and sparsity rises as the LR decays.

How it works

Trainer side

  • LowByteChangeDetector hooks into the optimizer and snapshots only the low byte of each weight's bf16 pattern (1 B/elem, half a full clone). A flipped low byte ⊆ a changed bf16 value, so precision is 1.0 by construction; recall (measured) is 1.0 in normal training. Misses cause inference drift, bounded by periodic full anchors.
  • delta_codec.py does the sparse extraction on the GPU: extract_sparse_batched runs a singlenonzero (to device sync) split across all changed params (instead of one nonzero/param) (~16× less than the old dense-D2H path).
  • Index encodings (main difference with the previous implementation): raw (int32), gap_delta (uint16 gaps, 2×), nvcomp_cascaded (uses GPU Cascaded delta+bitpack, ~3×, optional added a dep).
  • DeltaWeightTransferEngine.upload writes one safetensors patch (anchor = full tensors; delta ={name}.idx + {name}.val) and pushes it to the bucket. Self-describing format: names from the .val keys, encoding from a global header field, gap-delta width from the index dtype.

Lifecycle: the trainer drives vLLM's start_weight_update / update_weights /finish_weight_update HTTP routes; the change detector is created in compute_loss before the first optimizer.step.
The NCCL path is untouched.

Requirements/constraints

  • vLLM with sparse weight transfer ([Frontend][Core] Add sparse NCCL weight transfer support for in-place updates vllm-project/vllm#40096) merged to main after v0.22.0, not in a release yet; install from nightly. New delta_weight_sync extra added to pyproject.toml.
  • Serve with --model-impl transformers and VLLM_USE_V2_MODEL_RUNNER=0 (apply_sparse_weight_patches exists only on the V1 runner 😢 ). Example:
    CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 \
    vllm serve Qwen/Qwen3-1.7B \
          --model-impl transformers \
          --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \
          --weight-transfer-config '{"backend":"delta"}' \
          --max-model-len 2560

🔴🔴 Sparse apply is TP=1 / PP=1 (enforced by vLLM). Dense/small models today; sharded (TP>1 / EP /
fused-MoE) is future work.

Tests

  • Low-byte detector: recall 1.0 / precision 1.0 vs the full bf16 diff
  • Codec/file round-trip: bit-exact for raw / gap_delta / nvcomp_cascaded, and extract_sparse_batched
  • End to end AsyncGRPO (Qwen3-1.7B, GSM8K): sparse deltas apply with 0 failures (~91→99% sparse), reward improves 0.10 → 0.50 through the delta path; receiver timing anchor download-bound, vllm sidedecode+apply ~1 s.

References

AI writing disclosure

  • AI-assisted: parts were suggested/iterated with an AI tool, written and reviewed by a human.

Note

High Risk
Changes the training–inference weight path and depends on unreleased vLLM sparse apply; failed applies are logged and inference may run on stale weights until the next anchor.

Overview
Adds optional delta weight sync to experimental AsyncGRPO so vLLM can be updated with sparse bf16 patches over an HF Storage Bucket instead of a full NCCL broadcast on every sync.

Trainer path: New delta_sync_* config flags wire into WeightTransferClient. LowByteChangeDetector (optimizer hooks) tracks which weights changed; _sync_weight uploads patches while inference still runs, then pauses vLLM and calls apply_weights_delta. Periodic anchors (full tensors) limit drift from low-byte detection misses. NCCL full sync stays the default when delta sync is off.

New modules: delta_codec (GPU sparse extract + raw / gap_delta / nvcomp_cascaded index encoding), weight_diff (PatchMetadata, change detectors), delta_engine (safetensors encode/upload + vLLM delta weight-transfer engine registration via DeltaWorkerExtension).

Also adds examples/scripts/async_grpo_delta.py and optional delta_weight_sync extra (huggingface-hub, optional nvcomp).

Reviewed by Cursor Bugbot for commit cd83506. Bugbot is set up for automated code reviews on this repo. Configure here.

full = param.full_tensor() if isinstance(param, DTensor) else param.detach()
if full.device != device:
full = full.to(device)
yield name, full, mask
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Periodic anchors omit unchanged params

High Severity

After training starts, _streaming_iter_delta only yields parameters the low-byte detector marked as changed. Periodic anchor uploads still use that iterator and only set mask=None on those tensors, so vLLM receives a partial checkpoint instead of a full model refresh. If no weights changed, the anchor upload can be skipped entirely while inference keeps stale weights.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 36d54db. Configure here.

kwargs[k] = Encoding(v)
else:
kwargs[k] = v
return cls(**kwargs)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metadata dict type coercion broken

High Severity

PatchMetadata.from_metadata_dict compares each field’s type to the strings "int", "float", "bool", and "Encoding", but dataclass f.type values are actual types (int, float, etc.). Safetensors header values stay as strings, so vLLM logging and any numeric use of model_version or sparsity can raise TypeError after a successful patch apply.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 36d54db. Configure here.

# vLLM fetches the patch from the bucket inside this call; a full anchor can take minutes, so
# the timeout must cover the download — otherwise a read-timeout would retry into a re-download.
self._post_vllm("/update_weights", {"update_info": info}, retries=5, timeout=1800)
self._post_vllm("/finish_weight_update", {})
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weight update not finished on error

Medium Severity

apply_weights_delta calls /start_weight_update and /update_weights sequentially without try/finally. If /update_weights fails after retries, /finish_weight_update is never sent, while the trainer catches the exception and resumes vLLM anyway, which can leave the server in an inconsistent in-flight weight-update state.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 36d54db. Configure here.

self._delta_model_version += 1
is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0
if is_anchor:
iterator = ((name, tensor, None) for name, tensor, _mask in iterator) # strip masks -> full tensors
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Periodic anchors omit unchanged weights

High Severity

When is_anchor is true, upload_weights only strips masks on whatever _streaming_iter_delta yields. That iterator skips unchanged parameters, so periodic anchors upload a subset of tensors while vLLM treats them as a dense checkpoint. Unlisted weights stay stale on the inference side.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.

def init_weight_transfer(self) -> None:
self._wait_for_server_ready_sync()
if self.delta_sync_enabled:
create_bucket(self._delta_sync_repo_id, exist_ok=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delta sync without bucket id

Medium Severity

delta_sync_enabled can be set while delta_sync_repo_id stays None; init_weight_transfer then calls create_bucket with a null id. Training fails at startup instead of a clear config error.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.

name = name.removeprefix("module.") # DDP/FSDP1 wrapping
mask = masks.get(name) if masks else None
if masks and (mask is None or not mask.any()):
continue # unchanged param -> not in this delta
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FSDP collectives skipped on delta sync

High Severity

_streaming_iter_delta skips full_tensor() for parameters deemed unchanged, but non-main ranks still walk the same loop for FSDP2 collectives. Per-rank low-byte masks can differ under sharding, so ranks may skip different parameters and deadlock or corrupt the gather.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.

full = param.full_tensor() if isinstance(param, DTensor) else param.detach()
if full.device != device:
full = full.to(device)
yield name, full, mask
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shard masks vs full tensors

High Severity

For DTensor parameters, full_tensor() is global but change masks come from the local optimizer shard. Sparse encode pairs that mask with the gathered full tensor, so indices and values can be wrong or indexing can fail.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3166c0. Configure here.

@bot-ci-comment
Copy link
Copy Markdown

bot-ci-comment Bot commented Jun 4, 2026

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@AmineDiro AmineDiro changed the title Delta weight sync for AsyncGRPO (sparse patches over an HF Bucket) Delta weight sync for AsyncGRPO (sparse patches over HF Bucket) Jun 5, 2026
@AmineDiro
Copy link
Copy Markdown
Member Author

I've run a micro-benchmark (no GRPO loop): the policy is held on GPU 0, vLLM serves on GPU 1 (on the same node), and we time N weight syncs to compare FULL sync NCCL vs using buckets for small models.

model full size NCCL (full) delta 0.90 delta 0.99 delta 0.999 delta full-anchor
Qwen3-4B 8.04 GB 0.17 s 12.7 s 3.0 s 1.3 s 29.9 s
Qwen2.5-7B 15.23 GB 0.15 s 27.2 s 4.5 s 1.8 s 46.2 s
gemma-2-9b 18.48 GB 0.19 s 27.5 s 5.3 s 2.1 s 53.9 s

Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

There are 9 total unresolved issues (including 7 from previous reviews).

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit cd83506. Configure here.

prev = 0
for i, (name, _, _) in enumerate(items):
g = global_idx[prev : bounds[i]] - offsets[i] # local positions within this param
out.append((name, g.to(torch.int32), flats[i].index_select(0, g)))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU masks break GPU sparse extract

High Severity

LowByteChangeDetector keeps change masks on CPU, and _streaming_iter_delta passes them unchanged into extract_sparse_batched, which runs nonzero on CPU masks then index_select on GPU weight tensors. PyTorch expects indices on the same device as the source tensor, so sparse delta encoding can fail once any weight changes.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit cd83506. Configure here.

self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone()

def _post_step_hook(self, optimizer, args, kwargs) -> None:
self._validated_masks.clear()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multi-step sync drops earlier deltas

Medium Severity

The change detector clears and rebuilds _validated_masks after each optimizer step, retaining only the latest step’s changes. With weight_sync_steps greater than one, a sync sends just that last step’s sparse patch, so vLLM never receives updates from intervening steps.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit cd83506. Configure here.

@AmineDiro
Copy link
Copy Markdown
Member Author

AmineDiro commented Jun 5, 2026

Discussing with @qgallouedec and running some benchmarks. Even with 1 byte/param as a diff, we can't keep it on the GPU's VRAM, especially for big models. So we have to offload to CPU. But to reduce D2H sync time, I just send the low bits and use buckets in cd83506 and use pinned memory to do async transfer.

Mode & Sparsity Baseline Optimized H Speedup Peak GPU Overhead Retained GPU
CPU (Sparsity 0.90) 5.147s 1.652s 3.11× 1.94 GB 0.0 GB
CPU (Sparsity 0.99) 4.668s 1.631s 2.86× 1.94 GB 0.0 GB
CPU (Sparsity 0.999) 5.097s 1.632s 3.12× 1.94 GB 0.0 GB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant