[megatron] feat: model engine support mtp (#5561)
Conversation
Code Review
This pull request introduces support for Multi-Token Prediction (MTP) in the model engine, primarily affecting Megatron-based training and inference. Key changes include refactoring model forward passes to handle MTP-specific preprocessing and postprocessing, updating configuration files to enable MTP, and adding a new example script for MTP training. The changes also include improvements in handling nested tensors and position IDs. However, there are several areas that require attention to improve robustness, maintainability, and portability, particularly concerning hardcoded values, manual configuration steps, and potential behavioral changes in patched functions.
examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh (34)
The comment indicates a manual step to modify max_position_embeddings in config.json. Manual configuration steps are prone to human error and reduce the reproducibility and automation of the setup. This step should ideally be automated within the script or handled programmatically by the model-loading logic.
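As a sketch of how the manual edit could be automated, the helper below (hypothetical, not part of this PR) patches a HF-style config.json in place; the path and target value are assumptions to be wired into the launch script:

```python
import json
from pathlib import Path


def set_max_position_embeddings(config_path: str, value: int) -> None:
    """Patch max_position_embeddings in a HF-style config.json in place."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["max_position_embeddings"] = value
    path.write_text(json.dumps(config, indent=2))
```

Calling this from the script (or from the model-loading code) before training starts would remove the manual edit entirely.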
examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh (30-31)
The default values for NNODES and NGPUS_PER_NODE are set to 16 and 8 respectively, which are very high for a test script. These are then overridden to 1 and 4 within the fully_async array (lines 77-80). This inconsistency can lead to confusion regarding the actual resource allocation and might cause unexpected resource consumption or errors if the script is run without careful inspection. It is best to define these variables once or clearly indicate which values take precedence.
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-4}
examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh (48-49)
The use of magic numbers 2 and 3 in the calculations for actor_ppo_max_token_len and infer_ppo_max_token_len reduces the readability and maintainability of the script. It would be clearer to define these values as named variables with descriptive names, explaining their purpose.
ACTOR_PPO_FACTOR=2
INFER_PPO_FACTOR=3
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * ACTOR_PPO_FACTOR))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * INFER_PPO_FACTOR))
examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh (76)
The calculation 512*100 for rollout.total_rollout_steps is hardcoded. While simple, it would be more explicit and easier to modify if 512 and 100 were defined as named variables; alternatively, if the product is a fixed constant, assign it directly as a literal.
ROLLOUT_BASE_STEPS=512
ROLLOUT_MULTIPLIER=100
rollout.total_rollout_steps=$((ROLLOUT_BASE_STEPS * ROLLOUT_MULTIPLIER))
verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml (14-15)
The paths for RAY_DATA_HOME and TENSORBOARD_DIR are hardcoded to user-specific directories (/home/hadoop-djst-algoplat). This significantly reduces the portability and reusability of this runtime environment configuration across different systems or users. These paths should be made configurable (e.g., via environment variables that can be set externally) or use more generic relative paths.
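One option is to resolve these paths at launch time from the environment, falling back to generic defaults, instead of baking user-specific directories into the yaml. A minimal sketch (the variable names mirror the yaml; the helper itself is hypothetical and would live in the launcher that builds the Ray runtime environment):

```python
import os


def resolve_runtime_env() -> dict:
    """Build an env_vars mapping for a Ray runtime environment,
    preferring externally set variables over repo-relative defaults."""
    defaults = {
        "RAY_DATA_HOME": os.path.join(os.getcwd(), "data"),
        "TENSORBOARD_DIR": os.path.join(os.getcwd(), "tensorboard"),
    }
    return {"env_vars": {k: os.environ.get(k, v) for k, v in defaults.items()}}
```

With this pattern, a user only sets RAY_DATA_HOME or TENSORBOARD_DIR when the defaults do not fit their cluster.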
verl/models/mcore/model_forward.py (191-192)
In the _convert_to_nested_tensor function, if vi.shape[0] < target_len, the tensor is padded with torch.ones. Depending on the context and the data represented by vi (e.g., token IDs), padding with 1 might be semantically incorrect if 1 is a valid token ID. This could lead to unintended model behavior. It would be safer to pad with a specific pad_token_id from the tokenizer or a value that is guaranteed not to interfere with valid data.
vi = torch.cat([vi, torch.full((target_len - vi.shape[0],), self.pad_token_id, dtype=vi.dtype, device=vi.device)])
verl/models/mcore/mtp_patch.py (78-81)
The refactoring of _megatron_gptmodel_postprocess removes the explicit delegation to self._postprocess_backup for inference paths (when labels is None). While the new logic might cover the training path, it's crucial to verify that the inference behavior remains unchanged. If _postprocess_backup contained specific logic or optimizations for inference that are not replicated in the new combined logic, this change could introduce regressions or alter the model's behavior during evaluation.
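To make the concern concrete, the delegation-preserving pattern the review describes could look like the sketch below; the class and method names are illustrative only, not the actual Megatron/mtp_patch API:

```python
class PatchedModel:
    """Illustrative pattern: the patched postprocess only diverges on the
    MTP training path and delegates inference to the original function,
    so evaluation behavior is guaranteed unchanged."""

    def __init__(self, postprocess_backup):
        # Keep a reference to the unpatched postprocess.
        self._postprocess_backup = postprocess_backup

    def postprocess(self, hidden_states, labels=None):
        if labels is None:
            # Inference/eval path: must match the unpatched model exactly.
            return self._postprocess_backup(hidden_states)
        # Training path: MTP-aware loss computation goes here (sketch only).
        return self._compute_mtp_loss(hidden_states, labels)

    def _compute_mtp_loss(self, hidden_states, labels):
        raise NotImplementedError("sketch only")
```

If the combined logic in the PR is intended to replace this delegation, an equivalence test on the labels-is-None path would be a cheap way to rule out regressions.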
### What does this PR do?

model engine support mtp

verl-project#5323 broke using MTP in mbridge; revert that change.

Unload the KV cache before parameter synchronization (SGLang supports this first).

<img width="696" height="550" alt="image" src="https://github.com/user-attachments/assets/2aeacab4-b466-4d51-85d0-128b54ff13b2" />
<img width="704" height="580" alt="image" src="https://github.com/user-attachments/assets/08ac4490-c41a-4ddd-b522-6a1539e2e229" />

Throughput increased from an initial **3900 token/s** to **4800 token/s**, a **23% improvement**. The speculative acceptance rate increased from 44% to 54%, a 22% improvement.

<img width="1380" height="610" alt="image" src="https://github.com/user-attachments/assets/51da4d2e-3d12-4a71-8f48-f347e1c71896" />
<img width="2774" height="596" alt="image" src="https://github.com/user-attachments/assets/825084af-ba1e-4d58-ac9e-16f1251fd1e3" />

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [x] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.