
feat: [Qwen3.5] Support block-wise FP8 quantization and model adaptation#18926

Merged
ispobock merged 5 commits into sgl-project:main from zju-stu-lizheng:qwen3_5/fp8_fix
Feb 18, 2026
Conversation

@zju-stu-lizheng
Contributor

@zju-stu-lizheng zju-stu-lizheng commented Feb 17, 2026

Overview

This PR introduces support for block-wise FP8 quantization for the Qwen3.5 series and refines model adaptation logic for several architectures (Mistral-3, Qwen3-VL) to ensure compatibility with specialized quantization prefixes.

Key Changes

  1. Block-wise FP8 Scale Loading
    Implemented _load_merged_block_scale to handle block-level scale loading for merged column parallel linear layers.
    Added logic to calculate block sizes and offsets based on quantization configurations.
    Supported shard-wise weight block loading with full tensor parallel (TP) slicing.
    Enhanced weight_loader_v2 to handle BlockQuantScaleParameter types.

  2. Model Adaptation & Prefix Alignment
    Qwen3.5: Updated the quantization configuration prefix from 'model' to 'mtp' to match the MTP module hierarchy.
    Qwen3-VL: Disabled quantization for vision encoders to resolve known incompatibility issues and updated hierarchy prefixes to model.visual and model.language_model.
    Mistral-3: Restricted layer prefix replacements to prevent unintended side effects in other models.
    Consistency: Refactored layer naming logic and removed redundant/unused replacements to ensure cleaner model initialization.
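As a rough illustration of the block-scale loading in (1), the following sketch computes per-shard block counts and offsets and copies one rank's slice of a merged scale tensor. The function name and signature are illustrative assumptions and are not the actual sglang `_load_merged_block_scale` implementation.

```python
import torch

def load_merged_block_scale(param_data, loaded_scale, output_sizes,
                            block_n, shard_id, tp_rank, tp_size):
    """Illustrative sketch: copy this TP rank's slice of a block-wise scale
    tensor for one shard of a merged column-parallel layer."""
    # One scale row per block_n output rows: ceil-divide each shard's
    # per-rank output size by block_n to get its number of scale blocks.
    rank_block_sizes = [((size // tp_size) + block_n - 1) // block_n
                        for size in output_sizes]
    # Running offsets of each shard inside the merged scale parameter.
    offsets, acc = [], 0
    for n in rank_block_sizes:
        offsets.append(acc)
        acc += n
    n_blocks = rank_block_sizes[shard_id]
    # Take this rank's contiguous block slice from the checkpoint tensor
    # and write it into the shard's region of the merged parameter.
    src = loaded_scale.narrow(0, tp_rank * n_blocks, n_blocks)
    param_data.narrow(0, offsets[shard_id], n_blocks).copy_(src)
```

For example, with output_sizes=[8, 16], block_n=4, and tp_size=2, shard 1 owns two scale blocks per rank at offset 1 inside the merged parameter.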

Acknowledgments

Special thanks to @cao1zhg for the collaboration and valuable contributions to this implementation.

Checklist

  • Tested with Qwen3.5 FP8 checkpoints.
  • All tensor parallel (TP) slicing logic for scales is verified.

瑀澈 added 4 commits February 17, 2026 18:13
- Restrict ignored layer prefix replacement to Mistral-3 models to avoid unintended side effects.
- Update Qwen3.5 quantization configuration prefix from 'model' to 'mtp'.
- Disable quantization for Qwen3-VL vision encoders to resolve incompatibility issues.
- Refactor prefixes to 'model.visual' and 'model.language_model' to accurately reflect model hierarchy.
- Remove redundant layer name replacements and ensure naming consistency for vision layers.
- Add `_load_merged_block_scale` for processing block-level quantization scales.
- Compute block dimensions and offsets from quantization settings.
- Enable shard-by-shard weight block loading with tensor parallelism support.
- Integrate `BlockQuantScaleParameter` handling into `weight_loader_v2`.
@gemini-code-assist
Contributor

Summary of Changes

Hello @zju-stu-lizheng, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the quantization capabilities by introducing block-wise FP8 support for the Qwen3.5 series. It also refines the model adaptation logic across several architectures, including Qwen3.5, Qwen3-VL, and Mistral-3, to ensure proper handling of quantization configurations and layer naming conventions. These changes aim to improve compatibility and robustness when applying advanced quantization techniques to these models.

Highlights

  • Block-wise FP8 Quantization Support: Implemented a new method _load_merged_block_scale to handle block-wise scale loading for merged column parallel linear layers, supporting shard-wise weight block loading with full Tensor Parallel (TP) slicing. The weight_loader_v2 was updated to utilize this new functionality for BlockQuantScaleParameter types.
  • Qwen3.5 Model Adaptation: The quantization configuration prefix for Qwen3.5 models was updated from 'model' to 'mtp' to align with specific architecture requirements.
  • Qwen3-VL Model Adaptation: Quantization for the Qwen3-VL vision encoder was disabled to resolve known incompatibility issues. Additionally, hierarchy prefixes were updated to 'model.visual' and 'model.language_model' for better consistency.
  • Mistral-3 Layer Prefix Restriction: The logic for replacing 'model.' prefixes in ignored_layers was refined to apply only to Mistral-3 models, preventing unintended side effects in other architectures.
  • Refactored Layer Naming Logic: Redundant name replacement logic, specifically for 'model.visual.', was removed in the Qwen3-VL model's weight loading, contributing to cleaner model initialization.


Changelog
  • python/sglang/srt/layers/linear.py
    • Implemented _load_merged_block_scale to manage block-wise scale loading for merged column parallel linear layers.
    • Updated weight_loader_v2 to dispatch to _load_merged_block_scale for BlockQuantScaleParameter types.
  • python/sglang/srt/layers/quantization/fp8.py
    • Restricted the model. prefix removal from ignored_layers to only apply to Mistral-3 models to prevent unintended side effects.
  • python/sglang/srt/models/qwen3_5_mtp.py
    • Modified the quantization configuration prefix for Qwen3.5 models from 'model' to 'mtp' for alignment.
  • python/sglang/srt/models/qwen3_vl.py
    • Disabled quantization for the Qwen3-VL vision encoder by setting quant_config to None.
    • Updated the prefix for the vision encoder from 'visual' to 'model.visual'.
    • Updated the prefix for the language model from 'model' to 'model.language_model'.
    • Removed a redundant name.replace operation for 'model.visual.' in load_weights.
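The gated prefix rewrite in fp8.py can be sketched as follows; the function name and the is_mistral3 flag are assumptions for illustration, not the sglang API.

```python
def adjust_ignored_layers(ignored_layers, is_mistral3):
    """Strip the leading 'model.' from ignored-layer names, but only for
    Mistral-3 models, so other architectures keep their names intact.
    (Illustrative sketch; names are assumptions, not the sglang code.)"""
    if not is_mistral3:
        return list(ignored_layers)
    return [name.removeprefix("model.") for name in ignored_layers]
```

Gating the rewrite on the model family is what prevents the unintended side effects the PR describes for non-Mistral-3 architectures.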

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces support for block-wise FP8 quantization for Qwen3.5 and refactors model adaptation logic for several models, including Qwen3-VL and Mistral-3. The changes are well-structured and align with the goals described.

My review includes a suggestion to refactor a loop for calculating shard sizes and offsets into a more concise and Pythonic implementation using itertools.accumulate. I also pointed out a minor typo in a comment. Overall, the changes look good and improve quantization support and model compatibility.

Comment on lines +743 to +750
        shard_block_sizes = []
        shard_block_offsets = []
        current_block_offset = 0
        for output_size in self.output_sizes:
            shard_block_size = (output_size + block_n - 1) // block_n
            shard_block_sizes.append(shard_block_size)
            shard_block_offsets.append(current_block_offset)
            current_block_offset += shard_block_size
medium

The calculation of shard_block_sizes and shard_block_offsets can be made more concise and Pythonic by using a list comprehension and itertools.accumulate. This improves readability and maintainability.

        shard_block_sizes = [
            (output_size + block_n - 1) // block_n for output_size in self.output_sizes
        ]
        shard_block_offsets = [0] + list(itertools.accumulate(shard_block_sizes[:-1]))
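For reference, a quick self-contained check (with arbitrary example sizes) that the suggested accumulate form produces the same sizes and offsets as the explicit loop:

```python
import itertools

output_sizes = [2048, 2048, 5632]  # arbitrary example shard sizes
block_n = 128

# Explicit loop, as in the PR.
loop_sizes, loop_offsets, acc = [], [], 0
for output_size in output_sizes:
    n = (output_size + block_n - 1) // block_n
    loop_sizes.append(n)
    loop_offsets.append(acc)
    acc += n

# Suggested accumulate-based form.
sizes = [(s + block_n - 1) // block_n for s in output_sizes]
offsets = [0] + list(itertools.accumulate(sizes[:-1]))

assert sizes == loop_sizes and offsets == loop_offsets
```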

- Deprecate special handling for the 'language_model.mtp' prefix.
- Refactor prefix replacement to avoid unintended side effects from over-replacement.
- Keep the essential 'mtp.' to 'model.' replacement rule intact.
- Remove the redundant 'model.norm' replacement to prevent regressions.
@yuan-luo
Collaborator

@zju-stu-lizheng About the VLM update, could you please paste the running command for this enhancement?

@ispobock
Collaborator

/tag-and-rerun-ci

@zju-stu-lizheng
Contributor Author

zju-stu-lizheng commented Feb 18, 2026

The running command for loading the FP8 checkpoint, @yuan-luo:

TP_SIZE=8
DP_SIZE=4

python -m sglang.launch_server \
    --model Qwen/Qwen3.5-397B-A17B-FP8 \
    --tp-size ${TP_SIZE} \
    --dp-size ${DP_SIZE} \
    --enable-dp-attention \
    --enable-dp-lm-head \
    --enable-multimodal \
    --max-mamba-cache-size $((32 * 4)) \
    --max-running-requests $((32 * 4)) \
    --chunked-prefill-size $((256 * 8)) \
    --mem-fraction-static 0.85 \
    --model-loader-extra-config '{"enable_multithread_load": true,"num_threads": 64}' \
    --speculative-algo NEXTN    \
    --speculative-num-steps 3     \
    --speculative-eagle-topk 1     \
    --speculative-num-draft-tokens 4 

@ispobock ispobock merged commit fa5698d into sgl-project:main Feb 18, 2026
202 of 227 checks passed
@yuan-luo
Collaborator

@zju-stu-lizheng Thanks.
