@@ -248,7 +248,7 @@ def loss_func(logits, data):
     action_mask = data.get("action_mask")
     num_microbatches = data.get("num_microbatches")
 
-    dp_size = mpu.get_data_parallel_world_size()
+    dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
     tp_grp = mpu.get_tensor_model_parallel_group()
     tp_rank = mpu.get_tensor_model_parallel_rank()
@@ -346,14 +346,18 @@ def loss_func(logits, data):
     # when summing across the entire minibatch (see `apply_loss_reduction_to_advantages_minibatch`).
     # Megatron divides loss by num_microbatches
     # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248)
-    # and the data parallel all-reduce averages gradients across dp_size
+    # and the data parallel all-reduce averages gradients across dp_size (including CP ranks)
     # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285)
     # so we multiply by both factors to recover the correct sum reduction.
     grad_sum_correction_factor = num_microbatches * dp_size
 
     # NOTE: The KL and entropy loss terms are not pre-scaled,
     # so we just average them across microbatches and DP workers.
-    loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term
+    # Megatron's DDP averages gradients across the full DP+CP group,
+    # but KL/entropy should only be averaged across DP (not CP).
+    # Multiply by cp_size to counteract the unwanted CP averaging.
+    cp_size = mpu.get_context_parallel_world_size()
+    loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * cp_size
     unscaled_loss = loss / grad_sum_correction_factor
 
     # Build per-sequence loss_fn_outputs with logprobs.
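
A toy numeric sketch (all values assumed) of why the cp_size multiplier is needed: if each CP rank computes the KL term over its token shard, the DDP all-reduce's sum already reassembles the full-sequence KL across CP, but the subsequent division by dp_size * cp_size over-divides by cp_size relative to the intended DP-only mean.

```python
# Toy numbers (assumed): multiplying the mean-reduced terms by cp_size
# restores a DP-only average under Megatron's DP+CP gradient averaging.
dp_size, cp_size = 2, 2

# Full-sequence KL per DP replica; each CP rank holds an equal token shard.
kl_per_dp_replica = [4.0, 8.0]
kl_shards = [kl / cp_size for kl in kl_per_dp_replica for _ in range(cp_size)]

# Without the fix: the mean over the DP*CP group under-counts by cp_size.
naive = sum(kl_shards) / (dp_size * cp_size)                       # -> 3.0
# With the fix: pre-scaling each shard by cp_size gives the DP-only mean.
fixed = sum(k * cp_size for k in kl_shards) / (dp_size * cp_size)  # -> 6.0

assert fixed == sum(kl_per_dp_replica) / dp_size  # intended DP average: 6.0
```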
@@ -740,7 +740,7 @@ def forward_backward(
         # NOTE: Sum loss metrics because scaling is already applied at the advantage level
         status = reduce_metrics(all_metrics, sum_loss_metrics=sum_loss_metrics)
         status["policy_lr"] = self.optimizer.param_groups[0]["lr"]
-        group = mpu.get_data_parallel_group(with_context_parallel=True)
+        group = mpu.get_data_parallel_group(with_context_parallel=False)
         status = all_reduce_metrics(status, self.strategy, group=group, sum_loss_metrics=sum_loss_metrics)
 
         # Add loss_fn_outputs back (not reduced, kept as list)
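
For reference, a sketch (hypothetical helper; not the project's actual all_reduce_metrics signature) of averaging a scalar metric over the DP-only group that the changed line now selects, so CP ranks are excluded from the reduction:

```python
import torch
import torch.distributed as dist

def dp_only_mean(value: float, mpu) -> float:
    # with_context_parallel=False excludes CP ranks from the group.
    group = mpu.get_data_parallel_group(with_context_parallel=False)
    t = torch.tensor([value], dtype=torch.float64, device=torch.cuda.current_device())
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
    return (t / dist.get_world_size(group=group)).item()
```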
@@ -590,7 +590,7 @@ async def test_megatron_train(
             # the entropy calculation is different (fsdp has random logits for padding tokens)
             continue
         assert isinstance(result[k], (int, float)), f"{k} should be an int or float"
-        assert abs(result[k] - results_megatron[i][k]) < 1.5e-1, f"diff in {k} is too large!"
+        assert abs(result[k] - results_megatron[i][k]) < 2.5e-1, f"diff in {k} is too large!"
 
 
 @pytest.mark.asyncio