Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 27 additions & 29 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,53 +156,48 @@ def __init__(
# so param.numel()==0 and bucketing collapses to a single bucket.
# By the time extract_weights runs, the dispatch has already
# called prepare_for_weight_sync → _ensure_on_gpu.
self.param_buckets = None
self.bucket_index_groups = None
self._buckets_initialized = False

def _init_param_buckets(self):
"""Initialize parameter buckets for packing."""
# Get conversion tasks from bridge
weight_conversion_tasks = self.bridge.get_conversion_tasks(self.actor_module)
"""Compute bucket boundaries (index groups) from parameter sizes.

# Calculate size for each parameter
param_info = []
Only the bucket *structure* (which task indices go in which bucket) is
persisted. The actual ``WeightConversionTask`` objects are rebuilt on
every ``extract_weights`` call so that mapping objects start with clean
PP-collective caches, avoiding stale cached state across offload/reload
and training cycles.
"""
weight_conversion_tasks = self.bridge.get_conversion_tasks(self.actor_module)

def calculate_size_in_bytes(param, tp_size, ep_size):
if param is None:
# need to broadcast for other pp ranks
size_in_bytes = None
else:
# Calculate size for this parameter
prec_to_bytes = {
torch.bfloat16: 2,
torch.float32: 4,
}
scale = prec_to_bytes[self.training_dtype] / prec_to_bytes[param.dtype]
size_in_bytes = param.element_size() * param.numel() * tp_size * ep_size * scale

# Broadcast size_in_bytes across pipeline parallel ranks
return broadcast_object_across_pp_ranks(size_in_bytes)

for task in weight_conversion_tasks:
param_info.append(
(
task,
calculate_size_in_bytes(
task.param_weight,
task.mapping.tp_size,
task.mapping.ep_size if task.mapping.is_expert else 1,
),
)
sizes = [
calculate_size_in_bytes(
task.param_weight,
task.mapping.tp_size,
task.mapping.ep_size if task.mapping.is_expert else 1,
)
for task in weight_conversion_tasks
]

# Group parameters into buckets based on size threshold
self.param_buckets = [[]]
self.bucket_index_groups: list[list[int]] = [[]]
curr_size = 0
for task, size in param_info:
for idx, size in enumerate(sizes):
if curr_size + size > self.bucket_size_threshold_GB * 1024**3:
self.param_buckets.append([])
self.bucket_index_groups.append([])
curr_size = 0
self.param_buckets[-1].append(task)
self.bucket_index_groups[-1].append(idx)
curr_size += size

def get_weight_metadata(self, dtype: torch.dtype) -> dict:
Expand Down Expand Up @@ -262,7 +257,6 @@ def extract_weights(self, dtype: torch.dtype):
)

for name, tensor in hf_params_generator:
# Move to device and convert dtype
tensor = tensor.to(device=device, dtype=dtype, non_blocking=True)

yield WeightChunk(
Expand All @@ -272,12 +266,16 @@ def extract_weights(self, dtype: torch.dtype):
tensors=[tensor],
)
else:
# Bucketing mode: iterate over buckets, yield one chunk per bucket
for bucket in self.param_buckets:
# Build fresh tasks each sync so mapping objects have clean
# PP-collective caches; reuse the pre-computed bucket structure.
fresh_tasks = self.bridge.get_conversion_tasks(self.actor_module)

for index_group in self.bucket_index_groups:
bucket_tasks = [fresh_tasks[i] for i in index_group]
hf_params_generator = self.bridge.export_hf_weights(
self.actor_module,
show_progress=False,
conversion_tasks=bucket,
conversion_tasks=bucket_tasks,
)

# Collect all parameters in this bucket into one chunk
Expand Down
Loading