Fix completion_mask alignment and temperature scaling in Megatron GRPO trainer #8427
@@ -279,6 +279,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
        truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device)
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
        if template.padding_free:
            # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
            # Calculate seq_lengths from cu_seq_lens or position_ids
@@ -290,7 +292,7 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
            max_seq_len = seq_lengths.max().item()

            # completion_mask in rmpad format [1, total_tokens]
            completion_mask_rmpad = (labels != -100).float()
            completion_mask_rmpad = (rolled_labels != -100).float()
            completion_mask, _ = pad_logps_back_to_batch(
                logps_rmpad=completion_mask_rmpad,
                logits_to_keep=max_seq_len,
@@ -312,8 +314,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
            seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
            max_seq_len = labels.shape[-1]

            # completion_mask is already [batch_size, seq_len] in non-padding_free mode
            completion_mask = (labels != -100)
            # completion_mask based on rolled labels for alignment with per_token_logps
            completion_mask = (rolled_labels != -100)

        encoded_batch.update({
            'completion_mask': completion_mask,  # [batch_size, max_seq_len]
@@ -359,7 +361,6 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
                flat_lps, dtype=torch.float32, device=self.device)

            encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps

        return encoded_batch

    def _generate_and_score_completions(self, batch):
@@ -934,34 +935,31 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        inputs = self._prepare_model_inputs(batch)
        if self.beta != 0.0:
            with torch.no_grad(), self.null_ref_context() as ref_models:
            with self.null_ref_context() as ref_models:
                assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
                ref_model = ref_models[0]
                ref_per_token_logps_raw = self.model_forward(
                    ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
                ref_per_token_logps_packed = self.compute_per_token_logps(
                    ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
                if self.template.padding_free:
                    # In padding_free mode, logps are in rmpad format [1, total_tokens]
                    # Pad to batch format [batch_size, max_seq_len]
                    ref_per_token_logps, _ = pad_logps_back_to_batch(
                        logps_rmpad=ref_per_token_logps_raw,
                        logps_rmpad=ref_per_token_logps_packed,
                        logits_to_keep=max_seq_len,
                        batch_size=batch_size,
                        seq_lengths=seq_lengths)
                else:
                    # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
                    ref_per_token_logps = ref_per_token_logps_raw
                    ref_per_token_logps = ref_per_token_logps_packed
                batch['ref_per_token_logps'] = ref_per_token_logps

        old_per_token_logps_raw = self.model_forward(
            self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
        old_per_token_logps_packed = self.compute_per_token_logps(
            self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
        if self.template.padding_free:
            old_per_token_logps, _ = pad_logps_back_to_batch(
                logps_rmpad=old_per_token_logps_raw,
                logps_rmpad=old_per_token_logps_packed,
                logits_to_keep=max_seq_len,
                batch_size=batch_size,
                seq_lengths=seq_lengths)
        else:
            old_per_token_logps = old_per_token_logps_raw
            old_per_token_logps = old_per_token_logps_packed
        batch['old_per_token_logps'] = old_per_token_logps

        return batch
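For readers unfamiliar with the rmpad/padded conversion used throughout this diff: judging only from the call sites above, pad_logps_back_to_batch takes per-token values packed as [1, total_tokens] plus per-sample lengths and scatters them into a padded [batch_size, max_seq_len] tensor. A rough sketch of that behavior follows; the actual swift helper may differ in signature and return values.

```python
import torch

def pad_logps_back_to_batch_sketch(logps_rmpad, logits_to_keep, batch_size, seq_lengths, pad_value=0.0):
    """Scatter packed per-token values [1, total_tokens] into [batch_size, logits_to_keep]."""
    padded = logps_rmpad.new_full((batch_size, logits_to_keep), pad_value)
    mask = torch.zeros(batch_size, logits_to_keep, dtype=torch.bool, device=logps_rmpad.device)
    offset = 0
    for i, length in enumerate(seq_lengths.tolist()):
        # Copy the i-th sample's tokens into row i; the tail stays at pad_value.
        padded[i, :length] = logps_rmpad[0, offset:offset + length]
        mask[i, :length] = True
        offset += length
    return padded, mask
```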
@@ -1052,69 +1050,46 @@ def forward_step(self, data_iterator, model):
        # Check if this is the PP last stage (only last stage has labels and computes loss)
        is_pp_last_stage = mpu.is_pipeline_last_stage()

        if self.compute_entropy:
            # Forward without labels to get logits, then compute logps and entropy
            inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
            output_tensor = model(**inputs_for_logits)

            # Compute per_token_logps and per_token_entropy from logits on PP last stage
            if is_pp_last_stage and output_tensor is not None:
                # output_tensor is logits [batch/1, seq, partition_vocab_size]
                per_token_logps_raw, per_token_entropy_raw = compute_logps_and_entropy_from_logits(
                    output_tensor, labels, compute_entropy=True)

                # In CP mode, all_gather and reconstruct full sequence
                if args.context_parallel_size > 1:
                    num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
                    per_token_logps_raw = self._postprocess_packed_tensor_cp(per_token_logps_raw, packed_seq_params,
                                                                             num_samples)
                    per_token_entropy_raw = self._postprocess_packed_tensor_cp(per_token_entropy_raw, packed_seq_params,
                                                                               num_samples)

                if args.padding_free:
                    # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
                    per_token_logps, _ = pad_logps_back_to_batch(
                        logps_rmpad=per_token_logps_raw,
                        logits_to_keep=max_seq_len,
                        batch_size=micro_batch_size,
                        seq_lengths=seq_lengths)
        inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
        output_tensor = model(**inputs_for_logits)
        if is_pp_last_stage and output_tensor is not None:
            logits_packed = output_tensor
            if self.temperature != 1.0:
                logits_packed.div_(self.temperature)
Contributor:
For clarity and to avoid potential side effects, it's generally safer to perform tensor operations out-of-place, especially on tensors that are part of the computation graph. While the current in-place modification of logits_packed may work, consider changing this to an out-of-place division to improve code safety, unless the in-place operation is a deliberate memory optimization.
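The suggested change itself was not captured above; an out-of-place version along the lines the reviewer describes might look like this (a sketch against the new forward_step code, not a confirmed patch):

```python
# Out-of-place: produces a new tensor and leaves the model output untouched.
if self.temperature != 1.0:
    logits_packed = output_tensor / self.temperature
else:
    logits_packed = output_tensor
```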
            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
                logits_packed, labels, compute_entropy=self.compute_entropy)
Contributor, on lines +1059 to +1060:
The labels passed to compute_logps_and_entropy_from_logits here do not appear to be shifted to align with the logits. Please consider shifting the labels before this call.
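The reviewer's suggested change is likewise missing above. One way to apply the shift, reusing the rolled_labels pattern this PR already introduces in _get_encoded_batch (a sketch, not the reviewer's verbatim suggestion):

```python
# Shift labels left by one so that labels[t] is the token predicted from logits[t],
# mirroring rolled_labels = torch.roll(labels, shifts=-1, dims=-1) earlier in this PR.
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
    logits_packed, rolled_labels, compute_entropy=self.compute_entropy)
```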
            # In CP mode, all_gather and reconstruct full sequence
            if args.context_parallel_size > 1:
                num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
                per_token_logps_packed = self._postprocess_packed_tensor_cp(per_token_logps_packed, packed_seq_params,
                                                                            num_samples)
                if per_token_entropy_packed is not None:
                    per_token_entropy_packed = self._postprocess_packed_tensor_cp(per_token_entropy_packed,
                                                                                   packed_seq_params, num_samples)

            if args.padding_free:
                # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
                per_token_logps, _ = pad_logps_back_to_batch(
                    logps_rmpad=per_token_logps_packed,
                    logits_to_keep=max_seq_len,
                    batch_size=micro_batch_size,
                    seq_lengths=seq_lengths)
                if per_token_entropy_packed is not None:
                    per_token_entropy, _ = pad_logps_back_to_batch(
                        logps_rmpad=per_token_entropy_raw,
                        logps_rmpad=per_token_entropy_packed,
                        logits_to_keep=max_seq_len,
                        batch_size=micro_batch_size,
                        seq_lengths=seq_lengths,
                        pad_value=float('nan'))
                else:
                    per_token_logps = per_token_logps_raw
                    per_token_entropy = per_token_entropy_raw

                data['per_token_logps'] = per_token_logps
                data['per_token_entropy'] = per_token_entropy
        else:
            # Standard forward with labels, returns per-token loss (more efficient)
            output_tensor = model(**inputs)

            # Convert output_tensor (per-token loss) to per_token_logps on PP last stage
            if is_pp_last_stage and output_tensor is not None:
                per_token_logps_raw = self.get_logps(
                    output_tensor,
                    labels,
                    packed_seq_params,
                    packed_seq_params.num_samples if args.padding_free else micro_batch_size,
                    per_token=True)

                if args.padding_free:
                    per_token_logps, _ = pad_logps_back_to_batch(
                        logps_rmpad=per_token_logps_raw,
                        logits_to_keep=max_seq_len,
                        batch_size=micro_batch_size,
                        seq_lengths=seq_lengths)
                else:
                    per_token_logps = per_token_logps_raw
                per_token_entropy = None
        else:
            per_token_logps = per_token_logps_packed
            per_token_entropy = per_token_entropy_packed

            data['per_token_logps'] = per_token_logps
            data['per_token_entropy'] = None
        output_tensor = per_token_logps
        data['per_token_entropy'] = per_token_entropy

        return output_tensor, partial(self.loss_func, data=data)
@@ -1129,7 +1104,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
        # Get pre-computed per_token_logps and per_token_entropy from forward_step
        # These are already in batch format [batch_size, max_seq_len]
        per_token_logps = data.get('per_token_logps')
        per_token_logps = output_tensor
        per_token_entropy = data.get('per_token_entropy')

        # Get pre-padded ref/old/rollout logps from data
@@ -1409,38 +1384,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
        return loss, reporting_metric

    def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
        """Forward pass through model to compute logps.

        Args:
            model: The model to forward
            data_iterator: Iterator providing batch data
            no_grad: Whether to use torch.no_grad() context
            per_token: Whether to return per-token logps

        Returns:
            data dict containing 'logps'
        """
        # used to calculate model forward (logps) in GRPO
        data = self.get_batch(data_iterator)
        data.pop('loss_scale', None)
        input_ids = data.get('input_ids')
        labels = data.get('labels')
        context = torch.no_grad() if no_grad else nullcontext()

        with context:
            output_tensor = forward_step_helper(self.args, model, data)

        # packed_seq_params only exists in padding_free mode
        packed_seq_params = data.get('packed_seq_params')
        if packed_seq_params is not None:
            num_samples = packed_seq_params.num_samples
        else:
            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
        data['logps'] = None if labels is None else self.get_logps(
            output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
        return data

    def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
        """Convert raw input data into RolloutInferRequest objects"""
@@ -1,15 +1,16 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from megatron.core import mpu
from torch.distributed.nn import all_reduce
from transformers.utils import ContextManagers

from swift.megatron.model import get_mcore_model
from swift.megatron.utils import load_mcore_checkpoint
from swift.megatron.utils import forward_step_helper, load_mcore_checkpoint
from swift.rlhf_trainers.utils import identity_data_collator
from swift.utils import get_logger
from .base import BaseMegatronTrainer
from .vocab_parallel_utils import compute_logps_and_entropy_from_logits

logger = get_logger()
@@ -91,6 +92,49 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_t
        all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group())
        return all_logps

    def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperature=1.0):
        """Forward pass to get logits, then compute temperature-scaled per-token logps.

        Unlike get_logps (which recovers logps from cross-entropy loss), this method
        obtains raw logits from the model and computes logps with temperature scaling,
        which is required for importance sampling in GRPO and potentially other algorithms.

        Args:
            model: The model to forward
            data_iterator: Iterator providing batch data
            no_grad: Whether to disable gradient computation (default: True)
            temperature: Temperature for scaling logits before log_softmax

        Returns:
            per_token_logps tensor, or None if on a non-last PP stage
        """
        data = self.get_batch(data_iterator)
        data.pop('loss_scale', None)
        labels = data.get('labels')

        data_for_forward = {k: v for k, v in data.items() if k != 'labels'}
        context = torch.no_grad() if no_grad else nullcontext()
        with context:
            output_tensor = forward_step_helper(self.args, model, data_for_forward)

        if labels is None or output_tensor is None:
            return None

        if temperature != 1.0:
            output_tensor.div_(temperature)
Contributor:
Similar to the change suggested for forward_step above, consider performing this temperature scaling out-of-place rather than with div_.
        per_token_logps, _ = compute_logps_and_entropy_from_logits(output_tensor, labels)
Contributor:
Similar to the issue raised in forward_step, please consider shifting the labels before this call (see the rolled_labels sketch above).
        packed_seq_params = data.get('packed_seq_params')
        if packed_seq_params is not None:
            num_samples = packed_seq_params.num_samples
        else:
            input_ids = data.get('input_ids')
            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]

        if self.args.context_parallel_size > 1:
            per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples)
        return per_token_logps

    def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
        """
        Generic method: In CP mode, all_gather and reconstruct full tensor sequences.
Reviewer comment on the deepcopy(inputs) calls in _maybe_compute_logps:
Using deepcopy can be computationally expensive, especially when dealing with large tensors. Since compute_per_token_logps modifies the dictionary it receives by popping keys, a copy is necessary. However, a shallow copy using .copy() should be sufficient if the model's forward pass does not modify tensors in-place, and would be more performant. If you can ensure that the forward pass is free of in-place tensor modifications, consider using a shallow copy for both calls to compute_per_token_logps.
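A sketch of the direction the reviewer suggests for the two call sites in _maybe_compute_logps, assuming the forward pass does not mutate the input tensors in place:

```python
# Shallow copy: a new dict holding the same tensor objects, so compute_per_token_logps
# can pop keys without affecting `inputs`, while avoiding a full deepcopy of large tensors.
ref_per_token_logps_packed = self.compute_per_token_logps(
    ref_model, iter([inputs.copy()]), temperature=self.temperature)
old_per_token_logps_packed = self.compute_per_token_logps(
    self.unwrapped_models[0], iter([inputs.copy()]), temperature=self.temperature)
```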