diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 925a4b8a65..152348590a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -248,7 +248,7 @@ def loss_func(logits, data): action_mask = data.get("action_mask") num_microbatches = data.get("num_microbatches") - dp_size = mpu.get_data_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True) tp_grp = mpu.get_tensor_model_parallel_group() tp_rank = mpu.get_tensor_model_parallel_rank() @@ -346,14 +346,18 @@ def loss_func(logits, data): # when summing across the entire minibatch (see `apply_loss_reduction_to_advantages_minibatch`). # Megatron divides loss by num_microbatches # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248) - # and the data parallel all-reduce averages gradients across dp_size + # and the data parallel all-reduce averages gradients across dp_size (including CP ranks) # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285) # so we multiply by both factors to recover the correct sum reduction. grad_sum_correction_factor = num_microbatches * dp_size # NOTE: The KL and entropy loss terms are not pre-scaled, # so we just average them across microbatches and DP workers. - loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term + # Megatron's DDP averages gradients across the full DP+CP group, + # but KL/entropy should only be averaged across DP (not CP). + # Multiply by cp_size to counteract the unwanted CP averaging. + cp_size = mpu.get_context_parallel_world_size() + loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * cp_size unscaled_loss = loss / grad_sum_correction_factor # Build per-sequence loss_fn_outputs with logprobs. diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 138dfb2916..363267c7f3 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -740,7 +740,7 @@ def forward_backward( # NOTE: Sum loss metrics because scaling is already applied at the advantage level status = reduce_metrics(all_metrics, sum_loss_metrics=sum_loss_metrics) status["policy_lr"] = self.optimizer.param_groups[0]["lr"] - group = mpu.get_data_parallel_group(with_context_parallel=True) + group = mpu.get_data_parallel_group(with_context_parallel=False) status = all_reduce_metrics(status, self.strategy, group=group, sum_loss_metrics=sum_loss_metrics) # Add loss_fn_outputs back (not reduced, kept as list) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 6befe51e43..7c959e68fd 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -590,7 +590,7 @@ async def test_megatron_train( # the entropy calculation is different (fsdp has random logits for padding tokens) continue assert isinstance(result[k], (int, float)), f"{k} should be an int or float" - assert abs(result[k] - results_megatron[i][k]) < 1.5e-1, f"diff in {k} is too large!" + assert abs(result[k] - results_megatron[i][k]) < 2.5e-1, f"diff in {k} is too large!" @pytest.mark.asyncio