7 changes: 6 additions & 1 deletion examples/megatron/rlhf/gkd/opsd.sh
@@ -15,14 +15,19 @@ megatron rlhf \
     --model Qwen/Qwen3-4B \
     --external_plugins examples/train/rlhf/opsd/opsd_plugin.py \
     --dataset 'open-r1/OpenThoughts-114k-math' \
+    --use_vllm true \
+    --vllm_mode colocate \
+    --vllm_gpu_memory_utilization 0.6 \
+    --vllm_max_model_len 10240 \
+    --sleep_level 1 \
     --lmbda 1.0 \
     --beta 0.5 \
     --temperature 1.2 \
     --sft_alpha 0 \
     --torch_dtype bfloat16 \
     --micro_batch_size 1 \
     --global_batch_size 32 \
-    --max_steps 1000 \
+    --train_iters 1000 \
     --lr 2e-5 \
     --save_steps 100 \
     --save_total_limit 10 \
153 changes: 48 additions & 105 deletions swift/megatron/trainers/grpo_trainer.py
@@ -279,6 +279,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):

         truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device)

+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+
         if template.padding_free:
             # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
             # Calculate seq_lengths from cu_seq_lens or position_ids
@@ -290,7 +292,7 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
             max_seq_len = seq_lengths.max().item()

             # completion_mask in rmpad format [1, total_tokens]
-            completion_mask_rmpad = (labels != -100).float()
+            completion_mask_rmpad = (rolled_labels != -100).float()
             completion_mask, _ = pad_logps_back_to_batch(
                 logps_rmpad=completion_mask_rmpad,
                 logits_to_keep=max_seq_len,
@@ -312,8 +314,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
             seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
             max_seq_len = labels.shape[-1]

-            # completion_mask is already [batch_size, seq_len] in non-padding_free mode
-            completion_mask = (labels != -100)
+            # completion_mask based on rolled labels for alignment with per_token_logps
+            completion_mask = (rolled_labels != -100)

         encoded_batch.update({
             'completion_mask': completion_mask,  # [batch_size, max_seq_len]
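
A minimal, self-contained sketch (toy tensors, not the trainer's real batch) of why the completion mask is built from labels rolled left by one: if per-token logps at position i score the token at position i+1, the mask must flag position i whenever position i+1 carries a real (non -100) label. Note torch.roll wraps around, so the first label lands in the last slot; with a -100 prompt label first, the wrap-around position stays masked.

import torch

labels = torch.tensor([[-100, -100, 101, 102, 103]])      # -100 marks prompt tokens
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)    # [[-100, 101, 102, 103, -100]]
completion_mask = (rolled_labels != -100)                 # [[False, True, True, True, False]]
print(completion_mask)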
@@ -359,7 +361,6 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
                 flat_lps, dtype=torch.float32, device=self.device)

         encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps
-
         return encoded_batch

     def _generate_and_score_completions(self, batch):
@@ -934,34 +935,31 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:

         inputs = self._prepare_model_inputs(batch)
         if self.beta != 0.0:
-            with torch.no_grad(), self.null_ref_context() as ref_models:
+            with self.null_ref_context() as ref_models:
                 assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
                 ref_model = ref_models[0]
-                ref_per_token_logps_raw = self.model_forward(
-                    ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+                ref_per_token_logps_packed = self.compute_per_token_logps(
+                    ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
             if self.template.padding_free:
                 # In padding_free mode, logps are in rmpad format [1, total_tokens]
                 # Pad to batch format [batch_size, max_seq_len]
                 ref_per_token_logps, _ = pad_logps_back_to_batch(
-                    logps_rmpad=ref_per_token_logps_raw,
+                    logps_rmpad=ref_per_token_logps_packed,
                     logits_to_keep=max_seq_len,
                     batch_size=batch_size,
                     seq_lengths=seq_lengths)
             else:
                 # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
-                ref_per_token_logps = ref_per_token_logps_raw
+                ref_per_token_logps = ref_per_token_logps_packed
             batch['ref_per_token_logps'] = ref_per_token_logps

-        old_per_token_logps_raw = self.model_forward(
-            self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+        old_per_token_logps_packed = self.compute_per_token_logps(
+            self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
Contributor comment on lines +941 to +954 (severity: medium):

Using deepcopy can be computationally expensive, especially when dealing with large tensors. Since compute_per_token_logps modifies the dictionary it receives by popping keys, a copy is necessary. However, a shallow copy using .copy() should be sufficient if the model's forward pass does not modify tensors in-place, and would be more performant. If you can ensure that the forward pass is free of in-place tensor modifications, consider using a shallow copy for both calls to compute_per_token_logps.

Suggested change (replace deepcopy(inputs) with inputs.copy() in both calls):
-                ref_per_token_logps_packed = self.compute_per_token_logps(
-                    ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
+                ref_per_token_logps_packed = self.compute_per_token_logps(
+                    ref_model, iter([inputs.copy()]), temperature=self.temperature)
-        old_per_token_logps_packed = self.compute_per_token_logps(
-            self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
+        old_per_token_logps_packed = self.compute_per_token_logps(
+            self.unwrapped_models[0], iter([inputs.copy()]), temperature=self.temperature)
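
To make the reviewer's distinction concrete, here is a small, hypothetical sketch (not the trainer's real inputs dict): a shallow dict copy protects the original dict from key pops while sharing the underlying tensors, whereas deepcopy also clones every tensor.

import torch
from copy import deepcopy

inputs = {'input_ids': torch.ones(2, 4, dtype=torch.long),
          'labels': torch.zeros(2, 4, dtype=torch.long)}

shallow = inputs.copy()
shallow.pop('labels')                                  # the original dict keeps 'labels'
assert 'labels' in inputs
assert shallow['input_ids'] is inputs['input_ids']     # tensor storage is shared, not cloned

deep = deepcopy(inputs)                                # clones the dict AND every tensor
assert deep['input_ids'] is not inputs['input_ids']    # safe against in-place edits, but costly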

         if self.template.padding_free:
             old_per_token_logps, _ = pad_logps_back_to_batch(
-                logps_rmpad=old_per_token_logps_raw,
+                logps_rmpad=old_per_token_logps_packed,
                 logits_to_keep=max_seq_len,
                 batch_size=batch_size,
                 seq_lengths=seq_lengths)
         else:
-            old_per_token_logps = old_per_token_logps_raw
+            old_per_token_logps = old_per_token_logps_packed
         batch['old_per_token_logps'] = old_per_token_logps

         return batch
@@ -1052,69 +1050,46 @@ def forward_step(self, data_iterator, model):

         # Check if this is the PP last stage (only last stage has labels and computes loss)
         is_pp_last_stage = mpu.is_pipeline_last_stage()

-        if self.compute_entropy:
-            # Forward without labels to get logits, then compute logps and entropy
-            inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
-            output_tensor = model(**inputs_for_logits)
-
-            # Compute per_token_logps and per_token_entropy from logits on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                # output_tensor is logits [batch/1, seq, partition_vocab_size]
-                per_token_logps_raw, per_token_entropy_raw = compute_logps_and_entropy_from_logits(
-                    output_tensor, labels, compute_entropy=True)
-
-                # In CP mode, all_gather and reconstruct full sequence
-                if args.context_parallel_size > 1:
-                    num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
-                    per_token_logps_raw = self._postprocess_packed_tensor_cp(per_token_logps_raw, packed_seq_params,
-                                                                             num_samples)
-                    per_token_entropy_raw = self._postprocess_packed_tensor_cp(per_token_entropy_raw, packed_seq_params,
-                                                                               num_samples)
-
-                if args.padding_free:
-                    # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
+        inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
+        output_tensor = model(**inputs_for_logits)
+        if is_pp_last_stage and output_tensor is not None:
+            logits_packed = output_tensor
+            if self.temperature != 1.0:
+                logits_packed.div_(self.temperature)
Contributor comment (severity: medium):

For clarity and to avoid potential side effects, it's generally safer to perform tensor operations out-of-place, especially on tensors that are part of the computation graph. While the current in-place modification of logits_packed seems safe because output_tensor is not used elsewhere, an out-of-place operation would make the code more robust to future changes. Consider an out-of-place division, unless the in-place operation is a deliberate memory optimization.

Suggested change:
-                logits_packed.div_(self.temperature)
+                logits_packed = logits_packed / self.temperature
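
A toy illustration of the aliasing behind this suggestion (hypothetical tensors, independent of the trainer): logits_packed is just another name for output_tensor, so an in-place div_ mutates output_tensor as well, while the out-of-place form rebinds the name and leaves the original untouched.

import torch

output_tensor = torch.tensor([2.0, 4.0])
logits_packed = output_tensor                # alias, no copy made
logits_packed.div_(2.0)                      # in-place: output_tensor is now [1.0, 2.0] too
assert torch.equal(output_tensor, torch.tensor([1.0, 2.0]))

output_tensor = torch.tensor([2.0, 4.0])
logits_packed = output_tensor / 2.0          # out-of-place: new tensor, original untouched
assert torch.equal(output_tensor, torch.tensor([2.0, 4.0]))

Autograd adds a second reason to prefer the out-of-place form: an in-place op on a tensor that backward() still needs raises a runtime error, whereas the division above simply creates a new node in the graph.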

+            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
+                logits_packed, labels, compute_entropy=self.compute_entropy)
Contributor comment on lines +1059 to +1060 (severity: critical):

The labels tensor passed to compute_logps_and_entropy_from_logits appears to be unshifted. For autoregressive models, the labels should be shifted left by one position to align with the logits for next-token prediction (i.e., logits[..., i, :] predicts labels[..., i+1]). Using unshifted labels will result in misaligned log probabilities. Please consider shifting the labels before this call.

Suggested change:
-            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
-                logits_packed, labels, compute_entropy=self.compute_entropy)
+            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
+                logits_packed, torch.roll(labels, shifts=-1, dims=-1), compute_entropy=self.compute_entropy)
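
For context on the alignment question, here is a minimal sketch (hypothetical shapes, single sample, tiny vocab) of gathering per-token logps with left-shifted labels. Whether the shift is needed here depends on the input convention: the vocab_parallel_utils docstring below states that Megatron's get_batch_on_this_tp_rank already pre-shifts labels, in which case no extra roll is required.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 5)                     # [batch, seq, vocab]
labels = torch.tensor([[3, 1, 4, 2]])             # unshifted token ids

shifted = torch.roll(labels, shifts=-1, dims=-1)  # logits[i] now lines up with the token at i+1
logps = F.log_softmax(logits, dim=-1)
per_token_logps = logps.gather(-1, shifted.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)                      # torch.Size([1, 4]); the last slot is
                                                  # wrap-around and must be masked by the caller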


+            # In CP mode, all_gather and reconstruct full sequence
+            if args.context_parallel_size > 1:
+                num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
+                per_token_logps_packed = self._postprocess_packed_tensor_cp(per_token_logps_packed, packed_seq_params,
+                                                                            num_samples)
+                if per_token_entropy_packed is not None:
+                    per_token_entropy_packed = self._postprocess_packed_tensor_cp(per_token_entropy_packed,
+                                                                                  packed_seq_params, num_samples)
+
+            if args.padding_free:
+                # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
+                per_token_logps, _ = pad_logps_back_to_batch(
+                    logps_rmpad=per_token_logps_packed,
+                    logits_to_keep=max_seq_len,
+                    batch_size=micro_batch_size,
+                    seq_lengths=seq_lengths)
+                if per_token_entropy_packed is not None:
                     per_token_entropy, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_entropy_raw,
+                        logps_rmpad=per_token_entropy_packed,
                         logits_to_keep=max_seq_len,
                         batch_size=micro_batch_size,
                         seq_lengths=seq_lengths,
                         pad_value=float('nan'))
-                else:
-                    per_token_logps = per_token_logps_raw
-                    per_token_entropy = per_token_entropy_raw
-
-                data['per_token_logps'] = per_token_logps
-                data['per_token_entropy'] = per_token_entropy
-        else:
-            # Standard forward with labels, returns per-token loss (more efficient)
-            output_tensor = model(**inputs)
-
-            # Convert output_tensor (per-token loss) to per_token_logps on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                per_token_logps_raw = self.get_logps(
-                    output_tensor,
-                    labels,
-                    packed_seq_params,
-                    packed_seq_params.num_samples if args.padding_free else micro_batch_size,
-                    per_token=True)
-
-                if args.padding_free:
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
-                else:
-                    per_token_logps = per_token_logps_raw
-                per_token_entropy = None
+            else:
+                per_token_logps = per_token_logps_packed
+                per_token_entropy = per_token_entropy_packed

             data['per_token_logps'] = per_token_logps
-            data['per_token_entropy'] = None
+            output_tensor = per_token_logps
+            data['per_token_entropy'] = per_token_entropy

         return output_tensor, partial(self.loss_func, data=data)

@@ -1129,7 +1104,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

         # Get pre-computed per_token_logps and per_token_entropy from forward_step
         # These are already in batch format [batch_size, max_seq_len]
-        per_token_logps = data.get('per_token_logps')
+        per_token_logps = output_tensor
         per_token_entropy = data.get('per_token_entropy')

         # Get pre-padded ref/old/rollout logps from data
@@ -1409,38 +1384,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

         return loss, reporting_metric

-    def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
-        """Forward pass through model to compute logps.
-
-        Args:
-            model: The model to forward
-            data_iterator: Iterator providing batch data
-            no_grad: Whether to use torch.no_grad() context
-            per_token: Whether to return per-token logps
-
-        Returns:
-            data dict containing 'logps'
-        """
-        # used to calculate model forward (logps) in GRPO
-        data = self.get_batch(data_iterator)
-        data.pop('loss_scale', None)
-        input_ids = data.get('input_ids')
-        labels = data.get('labels')
-        context = torch.no_grad() if no_grad else nullcontext()
-
-        with context:
-            output_tensor = forward_step_helper(self.args, model, data)
-
-        # packed_seq_params only exists in padding_free mode
-        packed_seq_params = data.get('packed_seq_params')
-        if packed_seq_params is not None:
-            num_samples = packed_seq_params.num_samples
-        else:
-            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
-        data['logps'] = None if labels is None else self.get_logps(
-            output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
-        return data
-
     def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
         """Convert raw input data into RolloutInferRequest objects"""
48 changes: 46 additions & 2 deletions swift/megatron/trainers/rlhf_mixin.py
@@ -1,15 +1,16 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from megatron.core import mpu
 from torch.distributed.nn import all_reduce
 from transformers.utils import ContextManagers

 from swift.megatron.model import get_mcore_model
-from swift.megatron.utils import load_mcore_checkpoint
+from swift.megatron.utils import forward_step_helper, load_mcore_checkpoint
 from swift.rlhf_trainers.utils import identity_data_collator
 from swift.utils import get_logger
 from .base import BaseMegatronTrainer
+from .vocab_parallel_utils import compute_logps_and_entropy_from_logits

 logger = get_logger()
@@ -91,6 +92,49 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_t
             all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group())
         return all_logps

+    def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperature=1.0):
+        """Forward pass to get logits, then compute temperature-scaled per-token logps.
+
+        Unlike get_logps (which recovers logps from cross-entropy loss), this method
+        obtains raw logits from the model and computes logps with temperature scaling,
+        which is required for importance sampling in GRPO and potentially other algorithms.
+
+        Args:
+            model: The model to forward
+            data_iterator: Iterator providing batch data
+            no_grad: Whether to disable gradient computation (default: True)
+            temperature: Temperature for scaling logits before log_softmax
+
+        Returns:
+            per_token_logps tensor, or None if on a non-last PP stage
+        """
+        data = self.get_batch(data_iterator)
+        data.pop('loss_scale', None)
+        labels = data.get('labels')
+
+        data_for_forward = {k: v for k, v in data.items() if k != 'labels'}
+        context = torch.no_grad() if no_grad else nullcontext()
+        with context:
+            output_tensor = forward_step_helper(self.args, model, data_for_forward)
+
+        if labels is None or output_tensor is None:
+            return None
+
+        if temperature != 1.0:
+            output_tensor.div_(temperature)
Contributor comment (severity: medium):

Similar to the change in grpo_trainer.py, consider using an out-of-place division here for better code safety and clarity. While the in-place operation is currently safe because output_tensor is not reused, an out-of-place operation is more robust against future modifications.

Suggested change:
-            output_tensor.div_(temperature)
+            output_tensor = output_tensor / temperature

+        per_token_logps, _ = compute_logps_and_entropy_from_logits(output_tensor, labels)
Contributor comment (severity: critical):

Similar to the issue in grpo_trainer.py, the labels passed to compute_logps_and_entropy_from_logits here are unshifted. This will lead to misaligned log probabilities, as output_tensor (logits) at position i is for predicting token i+1. The labels should be shifted left by one.

Suggested change:
-        per_token_logps, _ = compute_logps_and_entropy_from_logits(output_tensor, labels)
+        per_token_logps, _ = compute_logps_and_entropy_from_logits(output_tensor, torch.roll(labels, shifts=-1, dims=-1))


+        packed_seq_params = data.get('packed_seq_params')
+        if packed_seq_params is not None:
+            num_samples = packed_seq_params.num_samples
+        else:
+            input_ids = data.get('input_ids')
+            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
+
+        if self.args.context_parallel_size > 1:
+            per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples)
+        return per_token_logps
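
As a single-GPU reference for what this method computes, here is a hedged sketch (assumed shapes; it ignores pipeline/context parallelism and the rmpad packing the trainer actually uses): scale logits by 1/temperature, take log_softmax over the vocab, and gather at the label ids, zeroing out -100 positions.

import torch
import torch.nn.functional as F

def per_token_logps_from_logits(logits, labels, temperature=1.0):
    # logits: [batch, seq, vocab]; labels: [batch, seq] with -100 for masked positions
    if temperature != 1.0:
        logits = logits / temperature
    logps = F.log_softmax(logits.float(), dim=-1)
    safe_labels = labels.clamp(min=0)                       # -100 would break gather
    token_logps = logps.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    return token_logps.masked_fill(labels == -100, 0.0)

logits = torch.randn(2, 5, 32)
labels = torch.randint(0, 32, (2, 5))
labels[:, :2] = -100                                        # mask the "prompt" part
print(per_token_logps_from_logits(logits, labels, temperature=1.2).shape)  # torch.Size([2, 5])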

     def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
         """
         Generic method: In CP mode, all_gather and reconstruct full tensor sequences.
6 changes: 5 additions & 1 deletion swift/megatron/trainers/vocab_parallel_utils.py
@@ -207,8 +207,12 @@ def compute_logps_and_entropy_from_logits(
     Note: In Megatron, labels are already shifted (via torch.roll in get_batch_on_this_tp_rank),
     so logits and labels are already aligned. No additional shift is needed here.

+    Temperature scaling should be applied by the caller before invoking this function,
+    so that this function remains a pure computation without side effects on the input.
+
     Args:
-        logits: Logits tensor [batch, seq, partition_vocab_size] or [1, total_tokens, partition_vocab_size]
+        logits: Logits tensor [batch, seq, partition_vocab_size] or [1, total_tokens, partition_vocab_size].
+            Should be pre-scaled by temperature if needed.
         labels: Token labels [batch, seq] or [1, total_tokens], -100 for masked positions
         compute_entropy: Whether to compute entropy (default: False)
         entropy_chunk_size: Chunk size for entropy computation (default: 512)
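
To illustrate why entropy is computed in chunks, here is a hedged sketch of the idea (assumed behavior only; the real function additionally handles the vocab-parallel partition and masking): processing a slice of token positions at a time avoids materializing a full [tokens, vocab] probability tensor alongside the logits.

import torch
import torch.nn.functional as F

def chunked_entropy(logits, chunk_size=512):
    # logits: [batch, seq, vocab] -> per-token entropy [batch, seq]
    pieces = []
    for chunk in logits.split(chunk_size, dim=1):          # slice along the token axis
        logps = F.log_softmax(chunk.float(), dim=-1)
        pieces.append(-(logps.exp() * logps).sum(dim=-1))  # H = -sum_v p_v log p_v
    return torch.cat(pieces, dim=1)

x = torch.randn(1, 1300, 64)
print(chunked_entropy(x, chunk_size=512).shape)            # torch.Size([1, 1300])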