Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import random
from argparse import Namespace
from itertools import accumulate

Expand Down Expand Up @@ -791,6 +792,15 @@ def update_weights(self) -> None: # type: ignore[override]
dist.barrier(group=get_gloo_group())

self.weight_updater.update_weights()

if self.args.ci_test and len(rollout_engines) > 0:
engine = random.choice(rollout_engines)
engine_version = ray.get(engine.get_weight_version.remote())
if str(engine_version) != str(self.weight_updater.weight_version):
raise RuntimeError(
f"Weight version mismatch! Engine: {engine_version}, Updater: {self.weight_updater.weight_version}"
)

clear_memory()

def _create_ref_model(self, ref_load_path: str | None):
Expand Down
16 changes: 8 additions & 8 deletions slime/backends/fsdp_utils/update_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class UpdateWeight(abc.ABC):
def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
self.args = args
self.model = model
self.weight_version = 0

@abc.abstractmethod
def connect_rollout_engines(
Expand All @@ -43,6 +44,7 @@ def connect_rollout_engines(
pass

def update_weights(self) -> None:
self.weight_version += 1
Comment on lines 46 to +47
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Initialize weight_version for distributed weight updater

UpdateWeight.update_weights now increments self.weight_version (lines 46‑47), but UpdateWeightFromDistributed overrides __init__ without calling the base constructor, so instances created when args.colocate is False never define weight_version. The next call to update_weights() raises AttributeError before any tensors are sent, breaking non-colocated training runs. Make sure UpdateWeightFromDistributed sets self.weight_version (e.g., by invoking super().__init__()) so distributed weight sync can run.

Useful? React with 👍 / 👎.

bucket = []
bucket_size = 0
for name, param in self.model.state_dict().items():
Expand Down Expand Up @@ -71,10 +73,10 @@ def update_weights(self) -> None:

def wait_and_update_bucket_weights(self, bucket):
bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket]
self.update_bucket_weights(bucket)
self.update_bucket_weights(bucket, weight_version=self.weight_version)

@abc.abstractmethod
def update_bucket_weights(self, named_tensors) -> None:
def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
pass


Expand Down Expand Up @@ -114,7 +116,7 @@ def connect_rollout_engines(
# Calculate TP rank within this SGLang engine group
self.tp_rank = dist.get_rank() - start_rank

def update_bucket_weights(self, named_tensors) -> None:
def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
monkey_patch_torch_reductions()
# Use flattened bucket approach similar to Megatron
logger.info("Using flattened tensor bucket")
Expand Down Expand Up @@ -162,6 +164,7 @@ def update_bucket_weights(self, named_tensors) -> None:
"serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches],
"load_format": "flattened_bucket",
"flush_cache": False,
"weight_version": str(weight_version),
}
ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs)
ray.get(ref)
Expand All @@ -174,10 +177,6 @@ def update_bucket_weights(self, named_tensors) -> None:
class UpdateWeightFromDistributed(UpdateWeight):
"""Broadcast weights via a temporary NCCL group to rollout engines."""

def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
self.args = args
self.model = model

def connect_rollout_engines(
self,
rollout_engines: Sequence[ActorHandle],
Expand Down Expand Up @@ -220,7 +219,7 @@ def connect_rollout_engines(
)
ray.get(refs)

def update_bucket_weights(self, named_tensors) -> None:
def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
"""Send names/dtypes/shapes metadata to engines, then broadcast tensors.

Ensures tensors are contiguous; when `world_size == 1`, converts DTensors
Expand All @@ -235,6 +234,7 @@ def update_bucket_weights(self, named_tensors) -> None:
dtypes=[param.dtype for _, param in named_tensors],
shapes=[param.shape for _, param in named_tensors],
group_name=self._group_name,
weight_version=str(weight_version),
)
for engine in self.rollout_engines
]
Expand Down
9 changes: 9 additions & 0 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import random
import socket
from argparse import Namespace
from contextlib import nullcontext
Expand Down Expand Up @@ -474,6 +475,14 @@ def update_weights(self) -> None:
self.weight_updater.update_weights()
print_memory("after update_weights")

if self.args.ci_test and len(rollout_engines) > 0:
engine = random.choice(rollout_engines)
engine_version = ray.get(engine.get_weight_version.remote())
if str(engine_version) != str(self.weight_updater.weight_version):
raise RuntimeError(
f"Weight version mismatch! Engine: {engine_version}, Updater: {self.weight_updater.weight_version}"
)

if getattr(self.args, "keep_old_actor", False):
if self.args.update_weights_interval == 1:
logger.info("updating model queue: rollout_actor -> old_actor, actor -> rollout_actor")
Expand Down
4 changes: 2 additions & 2 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _split_train_data_by_dp(self, data, dp_size):

def init_rollout_engines(args, pg, all_rollout_engines):
if args.debug_train_only:
return 0, None
return 0

num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
num_engines = args.rollout_num_gpus // num_gpu_per_engine
Expand Down Expand Up @@ -391,7 +391,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):
num_new_engines = len(rollout_engines)

if num_new_engines == 0:
return num_new_engines, None
return num_new_engines

if args.rollout_external:
addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines)
Expand Down
Loading