-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Delta weight sync for AsyncGRPO (sparse patches over HF Bucket) #5937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
081a21a
36d54db
e3166c0
ceae48b
cb4154e
cd83506
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| AsyncGRPO with delta weight sync (Transport B: HF Storage Bucket + in-place sparse apply). | ||
|
|
||
| Only changed bf16 weights are encoded as a sparse safetensors patch, uploaded to a bucket, and | ||
| applied in place on vLLM via PR #40096 — no full-model broadcast, no vLLM-side snapshot. | ||
|
|
||
| Start the vLLM server with the `delta` backend + worker extension (registers the engine) and the | ||
| `transformers` model impl (so vLLM's runtime param names match the trainer's HF names — every | ||
| param is then addressable by the in-place sparse apply, no fuse/unfuse remap needed): | ||
|
|
||
| # VLLM_USE_V2_MODEL_RUNNER=0 is required: the in-place sparse apply (apply_sparse_weight_patches, | ||
| # vLLM #40096) exists only on the V1 model runner. Without it the server picks V2 and every sparse | ||
| # delta update fails (the dense anchors still work, so it silently degrades to anchor-only sync). | ||
| CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 VLLM_USE_V2_MODEL_RUNNER=0 vllm serve Qwen/Qwen3-1.7B \ | ||
| --model-impl transformers \ | ||
| --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \ | ||
| --weight-transfer-config '{"backend":"delta"}' \ | ||
| --max-model-len 2560 | ||
|
|
||
| CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo_delta.py | ||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
|
|
||
| from datasets import load_dataset | ||
|
|
||
| from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
|
|
||
| logging.basicConfig( | ||
| level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), | ||
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||
| ) | ||
| logging.getLogger("trl").setLevel(logging.INFO) | ||
|
|
||
|
|
||
| def format_sample(sample): | ||
| """Map one dataset row to a single-turn chat prompt plus its reference solution. | ||
| The "question" text becomes the only user message; "solution" is the text after the last | ||
| "####" in "answer" (the whole answer if that marker is absent), with whitespace stripped.""" | ||
| return { | ||
| "prompt": [{"role": "user", "content": sample["question"]}], | ||
| "solution": sample["answer"].split("####")[-1].strip(), | ||
| } | ||
|
|
||
|
|
||
| def main() -> None: | ||
| """Run the example: load the GSM8K train split, reshape it with `format_sample`, build an | ||
| `AsyncGRPOConfig` with delta weight sync enabled, and train `AsyncGRPOTrainer` on it.""" | ||
| dataset = load_dataset("openai/gsm8k", "main", split="train") | ||
| # Drop every original column so only the keys returned by format_sample remain. | ||
| dataset = dataset.map(format_sample, remove_columns=dataset.column_names) | ||
|
|
||
| config = AsyncGRPOConfig( | ||
| output_dir="./results/async_grpo_delta", | ||
| per_device_train_batch_size=1, | ||
| num_generations=8, | ||
| max_completion_length=512, | ||
| max_steps=60, | ||
| learning_rate=1e-5, | ||
| logging_steps=1, | ||
| bf16=True, | ||
| report_to="none", | ||
| project="async_grpo_delta", | ||
| log_completions=True, | ||
| # Qwen3 thinking traces blow past the completion cap on GSM8K (truncated -> no answer -> | ||
| # zero reward); disable thinking so completions are short and accuracy_reward gets signal. | ||
| chat_template_kwargs={"enable_thinking": False}, | ||
| # --- delta weight sync (Transport B) --- | ||
| delta_sync_enabled=True, | ||
| delta_sync_repo_id="<your-hf-username>/async-grpo-delta-demo", # set to a bucket you own | ||
| delta_sync_anchor_interval=20, # full anchor every N syncs; sparse deltas in between | ||
| delta_sync_encoding="gap_delta", # raw | gap_delta | nvcomp_cascaded | ||
| ) | ||
| # The model is passed by name; it matches the model served by vLLM in the module docstring. | ||
| trainer = AsyncGRPOTrainer( | ||
| model="Qwen/Qwen3-1.7B", | ||
| args=config, | ||
| train_dataset=dataset, | ||
| reward_funcs=accuracy_reward, | ||
| ) | ||
| trainer.train() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
|
|
||
| from .async_grpo_config import AsyncGRPOConfig | ||
| from .async_rollout_worker import AsyncRolloutWorker | ||
| from .weight_diff import LowByteChangeDetector | ||
| from .weight_transfer import WeightTransferClient | ||
|
|
||
|
|
||
|
|
@@ -413,6 +414,7 @@ def __init__( | |
| self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} | ||
| self._train_tokens_start_time = None | ||
| self.model_version = 0 | ||
| self._change_detector: LowByteChangeDetector | None = None # delta sync only; created in compute_loss | ||
| # Create worker and queue on rank 0 | ||
| if self.accelerator.is_main_process: | ||
| if self.train_dataset is None: | ||
|
|
@@ -444,6 +446,10 @@ def __init__( | |
| "packed": True, | ||
| "is_checkpoint_format": True, | ||
| }, | ||
| delta_sync_enabled=self.args.delta_sync_enabled, | ||
| delta_sync_repo_id=self.args.delta_sync_repo_id, | ||
| delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, | ||
| delta_sync_encoding=self.args.delta_sync_encoding, | ||
| ) | ||
| self.rollout_worker = AsyncRolloutWorker( | ||
| model_name=model_name, | ||
|
|
@@ -517,6 +523,10 @@ def _set_signature_columns_if_needed(self): | |
| ] | ||
|
|
||
| def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): | ||
| # Register the change detector before the first optimizer.step so step 1 is captured and the | ||
| # first delta sync is already sparse (no-op unless delta sync is enabled). | ||
| self._maybe_init_change_detector() | ||
|
|
||
| input_ids = inputs["input_ids"] | ||
| attention_mask = inputs["attention_mask"] | ||
| completion_mask = inputs["completion_mask"] | ||
|
|
@@ -667,8 +677,48 @@ def _streaming_iter(self): | |
| full = full.to(device) | ||
| yield name, full | ||
|
|
||
| def _streaming_iter_delta(self): | ||
| """Like [`_streaming_iter`] but yields ``(name, full, mask)`` for delta sync. With an active | ||
| change detector, only changed params are yielded (with their element-level masks); otherwise all params with | ||
| ``mask=None`` (the cold anchor). All ranks must walk this identically so the FSDP2 ``full_tensor()`` | ||
| collectives line up. | ||
| """ | ||
| device = self.accelerator.device | ||
| # An empty mapping (no detector, or a detector whose `_validated_masks` is empty) makes every | ||
| # `if masks` check below false, so all params are then yielded with mask=None. | ||
| masks = self._change_detector._validated_masks if self._change_detector is not None else {} | ||
| for name, param in self.model.named_parameters(): | ||
| name = name.removeprefix("module.") # DDP/FSDP1 wrapping | ||
| mask = masks.get(name) if masks else None | ||
| # NOTE(review): this skip runs before full_tensor() below, so it is only collective-safe if | ||
| # every rank holds identical masks for every param - TODO confirm. | ||
| if masks and (mask is None or not mask.any()): | ||
| continue # unchanged param -> not in this delta | ||
| # DTensor params are gathered with full_tensor(); any other param is detached as-is. | ||
| full = param.full_tensor() if isinstance(param, DTensor) else param.detach() | ||
| if full.device != device: | ||
| full = full.to(device) | ||
| yield name, full, mask | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Periodic anchors omit unchanged params — High Severity. After training starts, … Additional Locations (1). Reviewed by Cursor Bugbot for commit 36d54db. Configure here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shard masks vs full tensors — High Severity. For … Additional Locations (1). Reviewed by Cursor Bugbot for commit e3166c0. Configure here. |
||
|
|
||
| def _maybe_init_change_detector(self): | ||
| """Create the bf16 change detector once the (prepared) optimizer exists, before its first | ||
| step, so the first delta sync is already sparse. No-op unless delta sync is enabled.""" | ||
| # Runs at most once: needs delta sync on, no detector created yet, and a truthy | ||
| # self.optimizer (getattr also covers the attribute not existing yet). | ||
| if self.args.delta_sync_enabled and self._change_detector is None and getattr(self, "optimizer", None): | ||
| # Unwrap AcceleratedOptimizer to the native torch optimizer (register_step_*_hook). | ||
| raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) | ||
| self._change_detector = LowByteChangeDetector(self.model, raw_optimizer) | ||
|
|
||
| def _sync_weight(self): | ||
| t0 = time.time() | ||
| is_delta = self.args.delta_sync_enabled | ||
|
|
||
| if is_delta: | ||
| # Delta phase 1: upload the sparse patch to the bucket while inference keeps running. | ||
| logger.info("Weight sync: uploading delta patch (inference still running)...") | ||
| if self.accelerator.is_main_process and self.weight_transfer: | ||
| self.weight_transfer.upload_weights(self._streaming_iter_delta()) | ||
| else: | ||
| # Non-rank-0 still walks the iterator so full_tensor() collectives complete. | ||
| for _ in self._streaming_iter_delta(): | ||
| pass | ||
| self.accelerator.wait_for_everyone() | ||
|
|
||
| # Pause vllm both delta and full | ||
| logger.info("Weight sync: pausing vLLM...") | ||
| if self.accelerator.is_main_process and self.weight_transfer: | ||
| self.weight_transfer.pause() | ||
|
|
@@ -679,7 +729,13 @@ def _sync_weight(self): | |
| t_barrier = time.time() | ||
|
|
||
| logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") | ||
| if self.accelerator.is_main_process and self.weight_transfer: | ||
| if is_delta: | ||
| if self.accelerator.is_main_process and self.weight_transfer: | ||
| try: | ||
| self.weight_transfer.apply_weights_delta() | ||
| except Exception as e: | ||
| logger.warning(f"Weight sync: apply failed ({e}), skipping, vLLM will use stale weights") | ||
| elif self.accelerator.is_main_process and self.weight_transfer: | ||
| self.weight_transfer.send_weights(self._streaming_iter()) | ||
| else: | ||
| # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. | ||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FSDP collectives skipped on delta sync
High Severity
`_streaming_iter_delta` skips `full_tensor()` for parameters deemed unchanged, but non-main ranks still walk the same loop for FSDP2 collectives. Per-rank low-byte masks can differ under sharding, so ranks may skip different parameters and deadlock or corrupt the gather. Additional Locations (1)
trl/experimental/async_grpo/async_grpo_trainer.py#L715-L718 — Reviewed by Cursor Bugbot for commit e3166c0. Configure here.