Skip to content

enable marlin fp8 blockwise#8990

Merged
zhyncs merged 7 commits intosgl-project:mainfrom
qeternity:marlin-enable-fp8-block
Aug 19, 2025
Merged

enable marlin fp8 blockwise#8990
zhyncs merged 7 commits intosgl-project:mainfrom
qeternity:marlin-enable-fp8-block

Conversation

@qeternity
Copy link
Contributor

Motivation

Update the Marlin FP8 implementation to allow blockwise quants to run on older hardware.

Resolves #7715

Modifications

Refactor weight post-processing so that we don't early return before Marlin prep. Auto enable Marlin for supported archs.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @qeternity, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've updated the Marlin FP8 implementation to enable blockwise quantization on older GPU hardware. This change also refactors the weight post-processing logic and introduces automatic Marlin enablement for supported architectures, addressing a reported issue.

Highlights

  • Marlin FP8 Blockwise Support: I've enabled the Marlin FP8 kernel to support blockwise quantization, which was previously not possible. This allows users with older GPU hardware (SM 8.0-8.8) to leverage Marlin for faster FP8 blockwise operations.
  • Revised Marlin Enablement Logic: The logic for determining self.use_marlin has been updated. It now automatically enables Marlin for CUDA devices with compute capability between SM 8.0 and SM 8.8, in addition to the existing SGLANG_FORCE_FP8_MARLIN environment variable. The previous explicit disablement for ROCm and block-wise FP8 has been removed.
  • Refactored Weight Post-processing: I've refactored the process_weights_after_loading method to ensure that weight preparation for Marlin (e.g., prepare_fp8_layer_for_marlin) is correctly applied regardless of whether block quantization is used. This involves restructuring the conditional logic for block_quant.
  • New GPU Capability Check: I've introduced a new utility function, can_auto_enable_marlin_fp8, which checks the GPU's compute capability to determine if it falls within the range (SM 8.0 to 8.8) where Marlin FP8 can be automatically enabled.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables Marlin FP8 quantization for blockwise quantized models. The changes involve refactoring the weight processing logic to allow Marlin preparation for blockwise quantized weights and adding a mechanism to auto-enable Marlin on supported hardware (Ampere GPUs).

My review focuses on code correctness and robustness. I've identified a minor issue with the use of a bare except clause, which could potentially mask errors. Overall, the changes are well-structured and achieve the intended goal.

@qeternity qeternity force-pushed the marlin-enable-fp8-block branch 2 times, most recently from b9cf2f2 to 17065d5 Compare August 9, 2025 01:13
@qeternity qeternity force-pushed the marlin-enable-fp8-block branch from 17065d5 to 2fccefe Compare August 9, 2025 01:14
@zhyncs zhyncs merged commit e483ab6 into sgl-project:main Aug 19, 2025
@PaulRoeseler
Copy link

Hi, thank you for this merge request!

I've tested this branch to confirm the fix. I had to build vllm from source to align with sglang dependency torch==2.8 environment.

Here are my findings:

  • Success: The changes work correctly for non-MoE FP8 models. I was able to successfully launch the server with Qwen/Qwen3-4B-Thinking-2507-FP8.

  • Remaining Issue: The server still fails to launch for MoE FP8 models like Qwen/Qwen3-235B-A22B-Thinking-2507-FP8. The crash occurs during CUDA graph compilation with the same Triton error as before.

Environment:

  • GPU: 4xNVIDIA A100 80GB
  • PyTorch: 2.8.0
  • SGLang/vLLM: Built from source on this branch's commit.

Failing Command:

python -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 --tp 4

Key Error:

triton.compiler.errors.CompilationError: ... ValueError("type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')")

It seems the core issue is specific to the fused_moe_kernel and its FP8 implementation on Ampere-architecture GPUs. Or am I doing something wrong? And do you plan to add support for MoE models too? I know that it works at vLLM already.

@qeternity
Copy link
Contributor Author

Hi @Anaudia - yes, this does not cover the MoE Marlin implementation. I am not familiar with the inner workings of that and we don't use MoEs in production.

@ehuaa
Copy link
Contributor

ehuaa commented Aug 27, 2025

Hi, thank you for this merge request!

I've tested this branch to confirm the fix. I had to build vllm from source to align with sglang dependency torch==2.8 environment.

Here are my findings:

  • Success: The changes work correctly for non-MoE FP8 models. I was able to successfully launch the server with Qwen/Qwen3-4B-Thinking-2507-FP8.
  • Remaining Issue: The server still fails to launch for MoE FP8 models like Qwen/Qwen3-235B-A22B-Thinking-2507-FP8. The crash occurs during CUDA graph compilation with the same Triton error as before.

Environment:

  • GPU: 4xNVIDIA A100 80GB
  • PyTorch: 2.8.0
  • SGLang/vLLM: Built from source on this branch's commit.

Failing Command:

python -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 --tp 4

Key Error:

triton.compiler.errors.CompilationError: ... ValueError("type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')")

It seems the core issue is specific to the fused_moe_kernel and its FP8 implementation on Ampere-architecture GPUs. Or am I doing something wrong? And do you plan to add support for MoE models too? I know that it works at vLLM already.

Hi @Anaudia , can you share your command of installing vLLM from source in sglang docker, when I install from source, it still uninstalled torch 2.8.0 and install torch 2.7

@qeternity
Copy link
Contributor Author

@ehuaa please consult the vLLM docs for instructions on how to use an existing pytorch installation.

@ehuaa
Copy link
Contributor

ehuaa commented Aug 27, 2025

@ehuaa please consult the vLLM docs for instructions on how to use an existing pytorch installation.

Thank you @qeternity, you really saved my day. I'll try it later.

@ehuaa
Copy link
Contributor

ehuaa commented Aug 28, 2025

Hi, thank you for this merge request!

I've tested this branch to confirm the fix. I had to build vllm from source to align with sglang dependency torch==2.8 environment.

Here are my findings:

  • Success: The changes work correctly for non-MoE FP8 models. I was able to successfully launch the server with Qwen/Qwen3-4B-Thinking-2507-FP8.
  • Remaining Issue: The server still fails to launch for MoE FP8 models like Qwen/Qwen3-235B-A22B-Thinking-2507-FP8. The crash occurs during CUDA graph compilation with the same Triton error as before.

Environment:

  • GPU: 4xNVIDIA A100 80GB
  • PyTorch: 2.8.0
  • SGLang/vLLM: Built from source on this branch's commit.

Failing Command:

python -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Thinking-2507-FP8 --tp 4

Key Error:

triton.compiler.errors.CompilationError: ... ValueError("type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')")

It seems the core issue is specific to the fused_moe_kernel and its FP8 implementation on Ampere-architecture GPUs. Or am I doing something wrong? And do you plan to add support for MoE models too? I know that it works at vLLM already.

Hi @Anaudia you can check this pr #9754, I have verified on Qwen/Qwen3-30B-A3B-Thinking-2507-FP8, Qwen3-235B-A22B-Thinking-2507-FP8, and DeepSeek-V3.1-FP8 on 2*8gpus

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature] Migrate support for FP8 in Ampere GPUs

5 participants