-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[megatron]feat: Add routing replay support for Megatron-Swift GRPO #8196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
affc976
03f718c
d8f03cf
15937e8
1e20732
75bddf2
b4113a5
e0700f6
f23c4ef
278f16e
c139541
0ba9443
177a4d9
46c5a70
35c8b80
4ee1125
44f76e8
fede8cf
a1edcc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -157,6 +157,8 @@ class MegatronModelConfig(TransformerConfig): | |
| 'none'] = 'aux_loss' | ||
| use_shared_expert_gate: bool = False | ||
|
|
||
| enable_routing_replay: bool = False | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I recommend using the parameter name Reference: NVIDIA/Megatron-LM#2101 |
||
|
|
||
| # mla | ||
| multi_latent_attention: bool = False | ||
| q_lora_rank: Optional[int] = None | ||
|
|
@@ -508,6 +510,10 @@ def get_mcore_model_config(args, hf_config): | |
| if num_moe_experts is None: | ||
| kwargs['expert_model_parallel_size'] = 1 | ||
| kwargs['expert_tensor_parallel_size'] = 1 | ||
|
|
||
| if args.router_replay_mode != "disabled": | ||
| kwargs['enable_routing_replay'] = True | ||
|
|
||
| config = MegatronModelConfig(**kwargs) | ||
| config.hf_config = hf_config | ||
| config.args = args | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,7 +23,9 @@ | |||||||||||||||||||||
| from swift.dataset import RowPreprocessor | ||||||||||||||||||||||
| from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput | ||||||||||||||||||||||
| from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments | ||||||||||||||||||||||
| from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed | ||||||||||||||||||||||
| from swift.megatron.utils import (forward_step_helper, get_padding_to, set_random_seed, | ||||||||||||||||||||||
| RouterReplay, RouterReplayAction, RouterReplayHelper, | ||||||||||||||||||||||
| get_router_replay_data, set_router_replay_data, get_local_topk_idx_for_current_rank) | ||||||||||||||||||||||
| from swift.rewards import orms | ||||||||||||||||||||||
| from swift.rlhf_trainers.grpo_trainer import DataType | ||||||||||||||||||||||
| from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, | ||||||||||||||||||||||
|
|
@@ -271,6 +273,8 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict: | |||||||||||||||||||||
| return batched_inputs, error_list | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _get_encoded_batch(self, encoded_list, rollout_batch, template): | ||||||||||||||||||||||
| original_seq_lengths = [item['length'] for item in encoded_list] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| args = self.args | ||||||||||||||||||||||
| encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -360,6 +364,40 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template): | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Validating and processing routed_experts data in R3 mode | ||||||||||||||||||||||
| if self.args.router_replay_mode == 'R3': | ||||||||||||||||||||||
| routed_experts_list = [] | ||||||||||||||||||||||
| cur_seq_lengths = seq_lengths | ||||||||||||||||||||||
| if (seq_lengths.size(0) > batch_size): | ||||||||||||||||||||||
| cur_seq_lengths = seq_lengths[:batch_size].clone() | ||||||||||||||||||||||
| cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size-1:].sum() | ||||||||||||||||||||||
| for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths): | ||||||||||||||||||||||
| routed_experts = data.get('routed_experts') | ||||||||||||||||||||||
| assert routed_experts is not None, 'When router_replay_mode = R3, routed_experts must be in rollout data' | ||||||||||||||||||||||
| routed_experts = torch.tensor(routed_experts) | ||||||||||||||||||||||
| # The number of experts in the output can be 1 less than (prompt_length + response_token_count) | ||||||||||||||||||||||
| # This gap of 1 is expected | ||||||||||||||||||||||
| # For more details, please refer PR https://github.com/vllm-project/vllm/pull/28284 | ||||||||||||||||||||||
| experts_seq_len = routed_experts.shape[0] | ||||||||||||||||||||||
| assert (experts_seq_len == original_seq_len | ||||||||||||||||||||||
| or experts_seq_len + 1 == original_seq_len), \ | ||||||||||||||||||||||
| f'The seq_len of routed_experts({experts_seq_len}) in output does not match the seq_len of data({original_seq_len}), should be equal to or 1 less than the seq_len of data' | ||||||||||||||||||||||
| # Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids | ||||||||||||||||||||||
| padding_routed_experts = routed_experts | ||||||||||||||||||||||
| padding_to = cur_seq_len if template.padding_free else max_seq_len | ||||||||||||||||||||||
| padding_len = padding_to - experts_seq_len | ||||||||||||||||||||||
| if padding_len > 0: | ||||||||||||||||||||||
| padding_right = template.padding_side == 'right' | ||||||||||||||||||||||
| padding_routed_experts = nn.functional.pad(routed_experts, | ||||||||||||||||||||||
| (0, 0, 0, 0, 0, padding_len) if padding_right else (0, 0, 0, 0, padding_len, 0), | ||||||||||||||||||||||
| 'constant', 0) | ||||||||||||||||||||||
| routed_experts_list.append(padding_routed_experts) | ||||||||||||||||||||||
| if template.padding_free: | ||||||||||||||||||||||
| gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| gloabl_routed_experts = torch.stack(routed_experts_list) | ||||||||||||||||||||||
| encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a typo in the variable name
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return encoded_batch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _generate_and_score_completions(self, batch): | ||||||||||||||||||||||
|
|
@@ -557,6 +595,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out | |||||||||||||||||||||
| if 'content' in choice.logprobs: | ||||||||||||||||||||||
| rollout_logprobs = [item['logprob'] for item in choice.logprobs['content']] | ||||||||||||||||||||||
| input_data['rollout_logprobs'] = [rollout_logprobs] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Step 6: Store rollout routed_experts for routing replay | ||||||||||||||||||||||
| input_data['routed_experts'] = choice.routed_experts | ||||||||||||||||||||||
| return input_data | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| assert len(batch) == len(outputs) | ||||||||||||||||||||||
|
|
@@ -951,9 +992,16 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: | |||||||||||||||||||||
| # In non-padding_free mode, logps are already in batch format [batch_size, seq_len] | ||||||||||||||||||||||
| ref_per_token_logps = ref_per_token_logps_raw | ||||||||||||||||||||||
| batch['ref_per_token_logps'] = ref_per_token_logps | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| old_per_token_logps_raw = self.model_forward( | ||||||||||||||||||||||
| self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if self.enable_routing_replay: | ||||||||||||||||||||||
| if self.args.router_replay_mode == 'R2': | ||||||||||||||||||||||
| RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) | ||||||||||||||||||||||
| if self.args.router_replay_mode == 'R3': | ||||||||||||||||||||||
| RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| output = self.model_forward(self.unwrapped_models[0], iter([deepcopy(inputs)]), | ||||||||||||||||||||||
| no_grad=True, per_token=True) | ||||||||||||||||||||||
| old_per_token_logps_raw = output['logps'] | ||||||||||||||||||||||
| if self.template.padding_free: | ||||||||||||||||||||||
| old_per_token_logps, _ = pad_logps_back_to_batch( | ||||||||||||||||||||||
| logps_rmpad=old_per_token_logps_raw, | ||||||||||||||||||||||
|
|
@@ -964,6 +1012,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: | |||||||||||||||||||||
| old_per_token_logps = old_per_token_logps_raw | ||||||||||||||||||||||
| batch['old_per_token_logps'] = old_per_token_logps | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if self.enable_routing_replay: | ||||||||||||||||||||||
| batch['routed_experts'] = output['layers_topk_idx'] | ||||||||||||||||||||||
| RouterReplay.clear_global_indices() | ||||||||||||||||||||||
| RouterReplay.clear_global_router_replay_action() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return batch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _compute_kl_from_batches(self, mini_batch_data: List[Dict[str, Any]]) -> torch.Tensor: | ||||||||||||||||||||||
|
|
@@ -1043,6 +1096,15 @@ def forward_step(self, data_iterator, model): | |||||||||||||||||||||
| 'seq_lengths': seq_lengths, | ||||||||||||||||||||||
| }) | ||||||||||||||||||||||
| data.pop('loss_scale', None) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if RouterReplayHelper.is_replay_backward_action(model.config): | ||||||||||||||||||||||
| router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) | ||||||||||||||||||||||
| for router in router_instance_list: | ||||||||||||||||||||||
| router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) | ||||||||||||||||||||||
| if RouterReplayHelper.is_replay_forward_action(model.config): | ||||||||||||||||||||||
| layers_topk_idx = data.pop('routed_experts', None) | ||||||||||||||||||||||
| set_router_replay_data(layers_topk_idx, model.config) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| inputs = self._prepare_model_inputs(data) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| labels = data['labels'] | ||||||||||||||||||||||
|
|
@@ -1116,6 +1178,11 @@ def forward_step(self, data_iterator, model): | |||||||||||||||||||||
| data['per_token_logps'] = per_token_logps | ||||||||||||||||||||||
| data['per_token_entropy'] = None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if RouterReplayHelper.is_replay_forward_action(model.config): | ||||||||||||||||||||||
| router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) | ||||||||||||||||||||||
| for router in router_instance_list: | ||||||||||||||||||||||
| router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return output_tensor, partial(self.loss_func, data=data) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @profiling_decorator | ||||||||||||||||||||||
|
|
@@ -1430,6 +1497,13 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): | |||||||||||||||||||||
| labels = data.get('labels') | ||||||||||||||||||||||
| context = torch.no_grad() if no_grad else nullcontext() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| layers_topk_idx = None | ||||||||||||||||||||||
| global_topk_idx = data.pop('routed_experts', None) | ||||||||||||||||||||||
| if RouterReplayHelper.is_replay_forward_action(model.config): | ||||||||||||||||||||||
| assert global_topk_idx is not None, "When router_replay_mode = R3, routed_experts must be in data" | ||||||||||||||||||||||
| layers_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, data.get('packed_seq_params')) | ||||||||||||||||||||||
| set_router_replay_data(layers_topk_idx, model.config) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| with context: | ||||||||||||||||||||||
| output_tensor = forward_step_helper(self.args, model, data) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -1441,6 +1515,12 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): | |||||||||||||||||||||
| num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0] | ||||||||||||||||||||||
| data['logps'] = None if labels is None else self.get_logps( | ||||||||||||||||||||||
| output_tensor, labels, packed_seq_params, num_samples, per_token=per_token) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if RouterReplayHelper.is_r2_record_action(model.config): | ||||||||||||||||||||||
| layers_topk_idx = get_router_replay_data(model.config) | ||||||||||||||||||||||
| if layers_topk_idx is not None: | ||||||||||||||||||||||
| data['layers_topk_idx'] = layers_topk_idx | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return data | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]: | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -245,6 +245,9 @@ def _prepare_vllm_engine(self): | |
| vllm_engine_kwargs = args.vllm_engine_kwargs or {} | ||
| load_format = vllm_engine_kwargs.pop('load_format', 'dummy') | ||
|
|
||
| if self.args.router_replay_mode == 'R3': | ||
| vllm_engine_kwargs['enable_return_routed_experts'] = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a vLLM version check? The |
||
|
|
||
| engine = GRPOVllmEngine( | ||
| args.model_info.model_dir, | ||
| torch_dtype=args.torch_dtype, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.