-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[megatron]feat: Add routing replay support for Megatron-Swift GRPO #8196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
affc976
03f718c
d8f03cf
15937e8
1e20732
75bddf2
b4113a5
e0700f6
f23c4ef
278f16e
c139541
0ba9443
177a4d9
46c5a70
35c8b80
4ee1125
44f76e8
fede8cf
a1edcc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,19 +2,39 @@ | |||||||
| import base64 | ||||||||
| import io | ||||||||
| import json | ||||||||
| import numpy as np | ||||||||
| import os | ||||||||
| import time | ||||||||
| import uuid | ||||||||
| from copy import deepcopy | ||||||||
| from dataclasses import asdict, dataclass, field, fields | ||||||||
| from PIL import Image | ||||||||
| from pydantic import BaseModel, Field, field_validator | ||||||||
| from typing import Any, Dict, List, Literal, Optional, Tuple, Union | ||||||||
| from pydantic import AfterValidator, BaseModel, Field, PlainSerializer, field_validator | ||||||||
| from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union | ||||||||
|
|
||||||||
| from swift.template import Messages, Tool | ||||||||
| from swift.utils import remove_response | ||||||||
|
|
||||||||
|
|
||||||||
| def serialize_ndarray(value): | ||||||||
| if value is None: | ||||||||
| return None | ||||||||
| if isinstance(value, np.ndarray): | ||||||||
| return {'data': value.tolist(), 'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True} | ||||||||
| return value | ||||||||
|
|
||||||||
|
|
||||||||
| def deserialize_ndarray(value): | ||||||||
| if value is None: | ||||||||
| return None | ||||||||
| if isinstance(value, dict) and value.get('__ndarray__'): | ||||||||
| return np.array(value['data'], dtype=value['dtype']).reshape(value['shape']) | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| return value | ||||||||
|
|
||||||||
|
|
||||||||
| NumpyArray = Annotated[Any, PlainSerializer(serialize_ndarray, return_type=Dict), AfterValidator(deserialize_ndarray)] | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class InferRequest: | ||||||||
| """ | ||||||||
|
|
@@ -392,6 +412,7 @@ class ChatCompletionResponseChoice: | |||||||
| finish_reason: Literal['stop', 'length', None] | ||||||||
| logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None | ||||||||
| token_ids: Optional[List[int]] = None | ||||||||
| routed_experts: Optional[NumpyArray] = None | ||||||||
|
|
||||||||
| def to_cmpl_choice(self) -> 'CompletionResponseChoice': | ||||||||
| self = deepcopy(self) | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -157,6 +157,8 @@ class MegatronModelConfig(TransformerConfig): | |
| 'none'] = 'aux_loss' | ||
| use_shared_expert_gate: bool = False | ||
|
|
||
| enable_routing_replay: bool = False | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I recommend using the parameter name Reference: NVIDIA/Megatron-LM#2101 |
||
|
|
||
| # mla | ||
| multi_latent_attention: bool = False | ||
| q_lora_rank: Optional[int] = None | ||
|
|
@@ -508,6 +510,10 @@ def get_mcore_model_config(args, hf_config): | |
| if num_moe_experts is None: | ||
| kwargs['expert_model_parallel_size'] = 1 | ||
| kwargs['expert_tensor_parallel_size'] = 1 | ||
|
|
||
| if args.router_replay_mode != 'disabled': | ||
| kwargs['enable_routing_replay'] = True | ||
|
|
||
| config = MegatronModelConfig(**kwargs) | ||
| config.hf_config = hf_config | ||
| config.args = args | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -245,6 +245,9 @@ def _prepare_vllm_engine(self): | |
| vllm_engine_kwargs = args.vllm_engine_kwargs or {} | ||
| load_format = vllm_engine_kwargs.pop('load_format', 'dummy') | ||
|
|
||
| if self.args.router_replay_mode == 'R3': | ||
| vllm_engine_kwargs['enable_return_routed_experts'] = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a vLLM version check? The |
||
|
|
||
| engine = GRPOVllmEngine( | ||
| args.model_info.model_dir, | ||
| torch_dtype=args.torch_dtype, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tolist() is inefficient