From 84c296aaf111bfcc23ff1dd35576ee49076f59fe Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 19 Mar 2026 00:00:17 +0000 Subject: [PATCH 1/2] retrieve get_conversion_tasks() on every weight sync to guarantee clean refs --- .../workers/megatron/megatron_worker.py | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 2a3a30ba52..9fbba11eb5 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -156,53 +156,48 @@ def __init__( # so param.numel()==0 and bucketing collapses to a single bucket. # By the time extract_weights runs, the dispatch has already # called prepare_for_weight_sync → _ensure_on_gpu. - self.param_buckets = None + self.bucket_index_groups = None self._buckets_initialized = False def _init_param_buckets(self): - """Initialize parameter buckets for packing.""" - # Get conversion tasks from bridge - weight_conversion_tasks = self.bridge.get_conversion_tasks(self.actor_module) + """Compute bucket boundaries (index groups) from parameter sizes. - # Calculate size for each parameter - param_info = [] + Only the bucket *structure* (which task indices go in which bucket) is + persisted. The actual ``WeightConversionTask`` objects are rebuilt on + every ``extract_weights`` call so that mapping objects start with clean + PP-collective caches, avoiding stale cached state across offload/reload + and training cycles. + """ + weight_conversion_tasks = self.bridge.get_conversion_tasks(self.actor_module) def calculate_size_in_bytes(param, tp_size, ep_size): if param is None: - # need to broadcast for other pp ranks size_in_bytes = None else: - # Calculate size for this parameter prec_to_bytes = { torch.bfloat16: 2, torch.float32: 4, } scale = prec_to_bytes[self.training_dtype] / prec_to_bytes[param.dtype] size_in_bytes = param.element_size() * param.numel() * tp_size * ep_size * scale - - # Broadcast size_in_bytes across pipeline parallel ranks return broadcast_object_across_pp_ranks(size_in_bytes) - for task in weight_conversion_tasks: - param_info.append( - ( - task, - calculate_size_in_bytes( - task.param_weight, - task.mapping.tp_size, - task.mapping.ep_size if task.mapping.is_expert else 1, - ), - ) + sizes = [ + calculate_size_in_bytes( + task.param_weight, + task.mapping.tp_size, + task.mapping.ep_size if task.mapping.is_expert else 1, ) + for task in weight_conversion_tasks + ] - # Group parameters into buckets based on size threshold - self.param_buckets = [[]] + self.bucket_index_groups: list[list[int]] = [[]] curr_size = 0 - for task, size in param_info: + for idx, size in enumerate(sizes): if curr_size + size > self.bucket_size_threshold_GB * 1024**3: - self.param_buckets.append([]) + self.bucket_index_groups.append([]) curr_size = 0 - self.param_buckets[-1].append(task) + self.bucket_index_groups[-1].append(idx) curr_size += size def get_weight_metadata(self, dtype: torch.dtype) -> dict: @@ -262,7 +257,6 @@ def extract_weights(self, dtype: torch.dtype): ) for name, tensor in hf_params_generator: - # Move to device and convert dtype tensor = tensor.to(device=device, dtype=dtype, non_blocking=True) yield WeightChunk( @@ -272,12 +266,16 @@ def extract_weights(self, dtype: torch.dtype): tensors=[tensor], ) else: - # Bucketing mode: iterate over buckets, yield one chunk per bucket - for bucket in self.param_buckets: + # Build fresh tasks each sync so mapping objects have clean + # PP-collective caches; reuse the pre-computed bucket structure. + fresh_tasks = self.bridge.get_conversion_tasks(self.actor_module) + + for index_group in self.bucket_index_groups: + bucket_tasks = [fresh_tasks[i] for i in index_group] hf_params_generator = self.bridge.export_hf_weights( self.actor_module, show_progress=False, - conversion_tasks=bucket, + conversion_tasks=bucket_tasks, ) # Collect all parameters in this bucket into one chunk @@ -817,7 +815,11 @@ async def broadcast_to_inference_engines( torch.cuda.empty_cache() # Extract and send weights using the sender created at init time + # from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE + # if _SKYRL_USE_NEW_INFERENCE: weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype) + # else: + # weight_metadata = None await self._weight_transfer_sender.send_chunks( self.weight_extractor.extract_weights(generator_dtype), weight_metadata=weight_metadata, From 7fe787686ebbc80b8eedab3eca0e41eadfc1b4df Mon Sep 17 00:00:00 2001 From: Eric Tang <46737979+erictang000@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:16:20 -0700 Subject: [PATCH 2/2] Update skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../backends/skyrl_train/workers/megatron/megatron_worker.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 9fbba11eb5..636c518154 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -815,11 +815,7 @@ async def broadcast_to_inference_engines( torch.cuda.empty_cache() # Extract and send weights using the sender created at init time - # from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE - # if _SKYRL_USE_NEW_INFERENCE: weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype) - # else: - # weight_metadata = None await self._weight_transfer_sender.send_chunks( self.weight_extractor.extract_weights(generator_dtype), weight_metadata=weight_metadata,