[Model] Support TP/PP/mamba2 kernel for PLaMo2 #19674
tlrmchlsmth merged 18 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add 🚀.
Summary of Changes
Hello @Alnusjaponica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the PLaMo2 model implementation by adding support for Pipeline Parallelism and improving Tensor Parallelism handling. It integrates the Mamba2 kernel for more efficient state-space model computations, including support for chunked prefill. The changes also involve refactoring the model's internal structure to use standard vLLM components and updating the weight loading logic accordingly.
Highlights
- **Pipeline Parallelism Support:** Implemented support for Pipeline Parallelism (PP) by modifying the `Plamo2Decoder` and `Plamo2PreTrainedModel` forward passes to handle intermediate tensors and using the `make_layers` utility.
- **Tensor Parallelism Support:** Updated linear layers (`ColumnParallelLinear`, `MergedColumnParallelLinear`, `RowParallelLinear`) in the Mamba and Attention mixers to correctly handle tensor parallelism, including weight loading and parameter shapes.
- **Mamba2 Kernel Integration:** Refactored the `Plamo2MambaMixer` to align with the Mamba2 implementation, integrating the `mamba_chunk_scan_combined` kernel for prefill and `selective_state_update` for decode, enabling chunked prefill and continuous batching for Mamba layers.
- **Unified Mamba/Attention Mixer Initialization:** Modified the `Plamo2MambaMixer` and `Plamo2AttentionMixer` constructors to accept `VllmConfig` directly, simplifying initialization.
- **Standard Component Usage:** Replaced custom RMSNorm and activation functions (`_rms_norm`, `_swiglu`) with vLLM's standard `RMSNorm` and `SiluAndMul` layers, and integrated the standard `Sampler`.
- **Weight Loading Updates:** Adjusted the `load_weights` method to handle the new RMSNorm weights in the Attention mixer and reshape the `in_proj` weights in the Mamba mixer to match the expected format for `MergedColumnParallelLinear`, supporting both unquantized and quantized weights.
- **Documentation Update:** Updated the `supported_models.md` file to indicate PLaMo2 now supports Tensor Parallelism.
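The `in_proj` reshaping mentioned above can be illustrated with a small sketch. Everything here is hypothetical (sizes, layout, and helper names are not taken from the PR): it only shows why a fused projection weight must be split into its sub-matrices before each can be column-sharded independently, which is the layout `MergedColumnParallelLinear`-style loading expects.

```python
import numpy as np

# Hypothetical sizes; the real PLaMo2 shapes come from its config.
hidden, inter, tp = 8, 16, 2

# Suppose the checkpoint stores in_proj as one [2*inter, hidden] matrix
# holding gate and x projections stacked along dim 0. A merged
# column-parallel layer shards each sub-matrix independently across TP
# ranks, so the loader must split first, then shard, then re-stack.
w = np.arange(2 * inter * hidden, dtype=np.float32).reshape(2 * inter, hidden)
gate, x = np.split(w, 2, axis=0)  # two [inter, hidden] blocks

def shard(mat, rank, tp):
    """Column-parallel sharding: split the output dimension across ranks."""
    rows = mat.shape[0] // tp
    return mat[rank * rows:(rank + 1) * rows]

# Rank 0 holds the first half of *each* sub-matrix, re-stacked.
rank0 = np.concatenate([shard(gate, 0, tp), shard(x, 0, tp)], axis=0)
assert rank0.shape == (2 * inter // tp, hidden)
```

Naively sharding the fused matrix without the split would give rank 0 only gate rows and rank 1 only x rows, which is why the reshape step in `load_weights` matters.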
Code Review
This pull request significantly enhances the PLaMo2 model by adding support for Tensor Parallelism (TP), Pipeline Parallelism (PP), and integrating the Mamba2 kernel for its Mamba layers. Key changes include:
- **Parallelism:** Standard vLLM utilities like `make_layers`, `is_pp_missing_parameter`, and `get_pp_group` are used to enable TP and PP. Model components (`Plamo2Model`, `Plamo2Decoder`, `Plamo2ForCausalLM`) are updated to handle distributed execution and intermediate tensor passing.
- **Mamba Mixer Refactor:** `Plamo2MambaMixer` is updated to use Mamba2-style parameters (e.g., `A`, `D` shapes related to `num_heads`) and leverages the `mamba_chunk_scan_combined` kernel for prefill, while retaining `selective_state_update` for decode. It now accepts `Mamba2Metadata`.
- **Attention Mixer Update:** `Plamo2AttentionMixer` now applies RMSNorm per head to the Q and K projections, a common feature in newer LLMs. The old custom RMSNorm and SwiGLU implementations are correctly replaced by standard vLLM layers (`RMSNorm`, `SiluAndMul`).
- **Configuration:** Model components now consistently use `VllmConfig` for initialization.
- **Weight Loading:** The `load_weights` method is substantially updated to handle new parameter names (e.g., for per-head Q/K norms) and complex reshaping for Mamba's `in_proj` layer to align with `MergedColumnParallelLinear` expectations. It also correctly skips loading weights for layers not on the current PP rank.
- **Documentation:** The `supported_models.md` file is updated to reflect the new TP and PP capabilities of PLaMo2.
The changes appear well-structured and aim to integrate PLaMo2 more deeply into the vLLM ecosystem. The most critical areas for review by the author would be the correctness of the Mamba parameter shape changes (if upgrading to Mamba2), the new per-head Q/K normalization in attention, and the intricate weight-reshaping logic for `in_proj` in `load_weights`, especially concerning various quantization methods. Comprehensive testing across different parallelism configurations and quantization schemes will be essential.
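For intuition about what the decode path computes, here is a toy, numpy-only single-token state update in the spirit of `selective_state_update` with a per-head scalar `A`, as in Mamba2. This is a sketch under illustrative names and shapes, not vLLM's kernel or its API.

```python
import numpy as np

def state_update(h, x, dt, A, B, C, D):
    """One decode step of a per-head SSM (Mamba2-style, scalar A per head).

    h: (heads, head_dim, d_state) recurrent state
    x: (heads, head_dim) input token features
    dt: (heads,) timestep; A: (heads,); B, C: (d_state,); D: (heads,)
    """
    dA = np.exp(A[:, None, None] * dt[:, None, None])          # discretized decay
    dBx = dt[:, None, None] * B[None, None, :] * x[..., None]  # input injection
    h_new = dA * h + dBx                                       # state recurrence
    y = (h_new * C[None, None, :]).sum(-1) + D[:, None] * x    # readout + skip
    return h_new, y

# Tiny smoke test with shapes chosen arbitrarily.
H, P, N = 2, 3, 4
h0 = np.zeros((H, P, N), dtype=np.float32)
x = np.ones((H, P), dtype=np.float32)
h1, y = state_update(h0, x, dt=np.ones(H), A=np.zeros(H),
                     B=np.ones(N), C=np.ones(N) / N, D=np.zeros(H))
assert h1.shape == (H, P, N)
assert np.allclose(y, 1.0)
```

At prefill time the same recurrence is evaluated over a whole chunk of tokens at once, which is the role `mamba_chunk_scan_combined` plays.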
```python
self.q_norm = RMSNorm(config.hidden_size_per_head,
                      eps=config.rms_norm_eps)
self.q_norm.weight = torch.nn.Parameter(
    torch.ones((self.num_heads, config.hidden_size_per_head)))
set_weight_attrs(self.q_norm.weight,
                 {"weight_loader": sharded_weight_loader(0)})
self.k_norm = RMSNorm(config.hidden_size_per_head,
                      eps=config.rms_norm_eps)
self.k_norm.weight = torch.nn.Parameter(
    torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
# Tensor-parallelism shards the K norm weights to the tp ranks
# in a head-wise manner. This approach does not work if there is only
# a single KV head, as is the case for PLaMo 2-1B.
if self.total_num_kv_heads != 1:
    set_weight_attrs(self.k_norm.weight,
                     {"weight_loader": sharded_weight_loader(0)})
```
The `q_weight` and `k_weight` parameters have been replaced with `q_norm` and `k_norm` respectively, which are RMSNorm layers applied per head. This is a significant architectural change, similar to what's seen in models like Llama 3.

- Please confirm whether this per-head normalization for Q and K is an intended upgrade/modification for the PLaMo2 architecture, or whether PLaMo2 originally had this feature.
- `sharded_weight_loader(0)` is used for `q_norm.weight` and `k_norm.weight`, which implies these weights are sharded along the head dimension. The conditional sharding for `k_norm.weight` when `self.total_num_kv_heads == 1` is a good detail for handling GQA/MQA correctly with TP.
- Per-head `RMSNorm` is expected. We had our own implementation for this, but this PR utilizes the vLLM native implementation.
- Thanks. As noted in the comment, some models have only one KV head, which is why we have this conditional sharding.
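Mechanically, per-head Q/K RMSNorm amounts to reshaping the flat projection into `(..., num_heads, head_dim)` and normalizing over the last axis with a per-head weight. A minimal numpy sketch for intuition (names and shapes are illustrative, not the PR's code):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the trailing (head_dim) axis, then scale.
    var = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(var + eps) * weight

seq, num_heads, head_dim = 3, 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((seq, num_heads * head_dim)).astype(np.float32)
w = np.ones((num_heads, head_dim), dtype=np.float32)

# Reshape to (seq, heads, head_dim), normalize each head independently
# (weight broadcasts per head), then flatten back to projection layout.
q_heads = q.reshape(seq, num_heads, head_dim)
q_normed = rms_norm(q_heads, w).reshape(seq, -1)

# Each head now has (approximately) unit RMS.
rms = np.sqrt((q_normed.reshape(seq, num_heads, head_dim) ** 2).mean(-1))
assert np.allclose(rms, 1.0, atol=1e-3)
```

Because the weight has one row per head, sharding it along dim 0 (`sharded_weight_loader(0)`) keeps each TP rank's weight aligned with the heads that rank owns.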
```diff
 self.A = nn.Parameter(
     torch.empty(
-        self.intermediate_size // tp_size,
-        self.ssm_state_size,
+        divide(self.num_heads, self.tp_size),
         dtype=torch.float32,
     ))
-self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
+self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size)))
 self.dt_bias = nn.Parameter(
     torch.ones(divide(self.num_heads, self.tp_size)))
```
The shapes and initialization of `self.A`, `self.D`, and `self.dt_bias` have changed significantly.

Previously:

- `self.A` was `(self.intermediate_size // tp_size, self.ssm_state_size)`.
- `self.D` was `torch.ones(self.intermediate_size // tp_size)`.
- `self.dt_bias` was initialized using `get_initial_dt_bias(self.num_heads)`.

Now:

- `self.A` is `torch.empty(divide(self.num_heads, self.tp_size), dtype=torch.float32)`. Note that the second dimension for `d_state` seems to be missing from the `torch.empty` call. The weight loader for `A` implies it is related to `A_log`, which is typically `(d_inner, d_state)` or `(num_heads, d_state)` in Mamba variants, whereas the current definition yields a 1D parameter of shape `(n_heads_per_tp,)` per TP shard, which might be incorrect if `A` is intended to be a matrix. Please double-check the intended shape and initialization for `self.A` against Mamba2's architecture. The `a_weight_loader` uses `lambda x: -torch.exp(x.float())`, which is the common transform for the `A_log` parameter in Mamba.
- `self.D` is `torch.ones(divide(self.num_heads, self.tp_size))`.
- `self.dt_bias` is `torch.ones(divide(self.num_heads, self.tp_size))`.

These changes suggest a shift from `intermediate_size`-based parameters to `num_heads`-based parameters, which is characteristic of Mamba2's per-head independent SSMs. Ensure these new shapes and initializations correctly reflect the PLaMo2 model's intended Mamba variant (original or Mamba2 upgrade).
This change is entirely expected. We had been manually broadcasting those parameters in `load_weights` to fit the mamba1 kernel API, even though PLaMo2's architecture is closer to mamba2 in the SSM layer. That duplication step in `load_weights` is also removed in this PR.
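The `A` handling discussed here follows the usual Mamba convention: checkpoints store `A_log`, and the kernel consumes `A = -exp(A_log)`, which vLLM attaches as a transform on the weight loader (the `lambda x: -torch.exp(x.float())` mentioned above, composed via `composed_weight_loader`). A tiny numpy sketch of just the transform, with hypothetical per-head values:

```python
import numpy as np

# Checkpoint-side parameter: A_log, one value per head (hypothetical data).
a_log = np.array([0.0, 0.5, 1.0, 2.0], dtype=np.float32)

# Kernel-side parameter: A = -exp(A_log) is strictly negative, so the
# discretized state decay exp(A * dt) always lies in (0, 1) for dt > 0.
a = -np.exp(a_log)
assert np.all(a < 0)
assert np.isclose(a[0], -1.0)
```

Applying the transform at load time means the forward pass never re-exponentiates, and removing the old broadcast-to-`intermediate_size` step keeps the stored parameter at its natural per-head size.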
Referenced code: `vllm/vllm/model_executor/models/plamo2.py`, lines 713 to 722 at aed8468.
Signed-off-by: Shinichi Hemmi <shemmi@preferred.jp>
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Co-authored-by: Sixue Wang <cecilwang@preferred.jp>
Force-pushed from c2a909a to ed09215.
tlrmchlsmth left a comment
Eval results look good to me
main (TP=1, PP=1)
|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.202|± |0.0127|
| | |strict-match | 5|exact_match|↑ |0.520|± |0.0158|
This PR (TP=2, PP=2)
|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.205|± |0.0128|
| | |strict-match | 5|exact_match|↑ |0.517|± |0.0158|
tlrmchlsmth left a comment
Could you please merge latest main? Once that's done I'll mark it ready and turn on automerge - thank you!
```diff
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
-                                                   SupportsV0Only)
+                                                   SupportsPP, SupportsV0Only)
```
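Declaring `SupportsPP` means the decoder layers are partitioned across pipeline ranks, and weights for layers owned by other ranks are skipped at load time (`is_pp_missing_parameter`). The contiguous partitioning behind `make_layers` can be sketched as follows; this is a simplified, illustrative stand-in, and vLLM's real logic (e.g. how it distributes a remainder) may differ:

```python
def pp_layer_range(num_layers: int, pp_rank: int, pp_size: int) -> tuple:
    """Contiguous [start, end) slice of layers owned by one pipeline rank.

    Illustrative only: gives one extra layer to the first `extra` ranks
    when num_layers is not divisible by pp_size.
    """
    per_rank = num_layers // pp_size
    extra = num_layers % pp_size
    start = pp_rank * per_rank + min(pp_rank, extra)
    end = start + per_rank + (1 if pp_rank < extra else 0)
    return start, end

# 32 layers over 4 ranks -> 8 contiguous layers each; any layer outside a
# rank's range is a "PP-missing" parameter during weight loading.
ranges = [pp_layer_range(32, r, 4) for r in range(4)]
assert ranges == [(0, 8), (8, 16), (16, 24), (24, 32)]
```

Non-final ranks then forward intermediate tensors to the next rank instead of computing logits, which is why the forward passes were changed to accept and return intermediate tensors.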
@Alnusjaponica do you have plans to add support for V1?
@nopperl is currently working on V1 support at https://github.com/pfnet/vllm/tree/plamo2-follow-up-v1 but has not yet caught up with the latest main branch. We are going to submit it as a separate PR.
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
@tlrmchlsmth Thanks for your review! I've merged the latest main and added an assertion for the tp_size check. Could you take another look?
Signed-off-by: Shinichi Hemmi <shemmi@preferred.jp> Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Calvin Metzger <metzger@preferred.jp> Co-authored-by: Sixue Wang <cecilwang@preferred.jp>
Signed-off-by: x22x22 <wadeking@qq.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
Essential Elements of an Effective PR Description Checklist

- Update `supported_models.md` and `examples` for a new model.

Purpose

Follow-up #14323 to support

Test Plan

Manually modify tests to use only `pfnet/plamo-2-1b` and run the following tests.

Test Result

(Optional) Documentation Update

Updated `docs/models/supported_models.md` in 89247de.