Skip to content

[Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+)#19652

Merged
BBuf merged 44 commits intosgl-project:mainfrom
Godmook:nvfp4-marlin-fallback
Apr 3, 2026
Merged

[Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+)#19652
BBuf merged 44 commits intosgl-project:mainfrom
Godmook:nvfp4-marlin-fallback

Conversation

@Godmook
Copy link
Copy Markdown
Contributor

@Godmook Godmook commented Mar 2, 2026

Motivation

Related Issue: #19491
NVFP4-quantized models (e.g., nvidia/Llama-3.1-8B-Instruct-NVFP4, nvidia/DeepSeek-V3-0324-FP4, mistralai/Minimax-M2.5-NVFP4) crash immediately on non-Blackwell GPUs because get_min_capability() returned 100.
This forces users on A100/A40/H100/RTX 3090 to fall back to less accurate quantization (AWQ/GPTQ) or switch to vLLM, which already supports this via Marlin fallback. This PR brings equivalent functionality to SGLang.

Key properties:

  • Weights remain compressed in FP4 (2 values per byte) — no VRAM explosion unlike naive BF16 dequantization
  • FP4 dequantization happens in-flight inside the Marlin kernel via fast bitwise operations during tensor core matmul
  • Fully automatic — no user-side flags needed; fallback is triggered based on GPU capability detection
  • All kernel changes are gated by w_type == kFE2M1f (FP4 only) — zero impact on existing INT4/INT8/FP8 quantization paths

Modifications

Summary

File Type Change
marlin_utils_fp4.py New Core NVFP4 Marlin utilities (prepare, apply, scale conversion)
test_nvfp4_marlin_fallback.py New 6 unit tests covering full fallback stack
marlin/marlin_template.h (JIT) Kernel fix Fix FP4 scale stride/group mismatch — aligns with vLLM
marlin/marlin_template.h (sgl-kernel) Kernel fix Same fix as JIT — identical changes
marlin_moe/marlin_template.h Kernel fix Uncomment dequant_fp8_scales() for kFE2M1f; fix global_scale application
marlin_moe/moe_wna16_marlin.cuh Modified Guard FP4 kernel instantiation behind SGL_MOE_MARLIN_FP4 macro
moe_wna16_marlin.py Modified Add separate FP4 JIT module to avoid compile-time regression for non-FP4
fused_marlin_moe.py Modified Add w1/w2_global_scale params; detect FP4 Marlin mode
moe_runner/marlin.py Modified Add w13/w2_global_scale fields to MarlinMoeQuantInfo
compressed_tensors_w4a4_nvfp4.py Modified get_min_capability 100→75; Marlin fallback in linear path
compressed_tensors_w4a4_nvfp4_moe.py Modified get_min_capability 100→75; Marlin fallback in MoE path
modelopt_quant.py Modified get_min_capability 100→75; Marlin fallback in linear + MoE path
weight_utils.py Modified Explicit FP8/FP4 config routing for modelopt checkpoints
environ.py Modified Add SGLANG_FORCE_NVFP4_MARLIN env var
environment_variables.md Modified Document new env var

1. New: marlin_utils_fp4.py

Core utility module. All Marlin-specific preparation and inference logic is isolated here.

Function Description
is_fp4_marlin_supported() Returns True if SM ≥ 75
nvfp4_marlin_process_scales() Converts per-group FP8-S1E4M3 scales to FP8-S0E5M3 format required by Marlin dequant kernel
nvfp4_marlin_process_global_scale() Pre-adjusts FP4↔BF16 exponent bias for gptq_marlin_gemm
apply_fp4_marlin_linear() Runs linear inference via gptq_marlin_gemm with float4_e2m1f scalar type
prepare_fp4_layer_for_marlin() Repacks linear weights (gptq_marlin_repack), permutes + converts scales, creates workspace
prepare_moe_fp4_layer_for_marlin() Same as above, applied per-expert for MoE layers

2. Kernel Bug Fixes

2a. Linear kernel: marlin/marlin_template.h (JIT + sgl-kernel)

Root cause: SGLang's Marlin kernel had incorrect FP4 scale stride calculations that diverged from vLLM's corrected implementation. The kernel assumed FP16-sized (2-byte) elements for scale strides, but NVFP4 uses FP8 scales (1 byte each). This caused the kernel to read every other scale row and misalign scale-to-weight-group mapping, producing garbage output on all non-Blackwell GPUs.

6 changes, all gated by w_type == kFE2M1f (zero impact on INT4/INT8/FP8 paths):

Variable Before (FP4 path) After (FP4 path) Impact on non-FP4
s_gl_stride prob_n / 8 prob_n / 16 None (ternary returns 8)
s_sh_stride 16 * thread_n_blocks / 8 16 * thread_n_blocks / 16 None (ternary returns 8)
s_tb_groups thread_k_blocks / group_blocks / 2 thread_k_blocks / group_blocks None (ternary returned /1)
s_gl_rd ... / 2 + ... ... + ... None (ternary returned /1)
s_sh_rd Special *2 + warp_row%2 block Shared logic with other types None (deleted block was FP4-only)
cur_group_id k_blocks / (group_blocks * 2) k_blocks / group_blocks None (ternary returned *1)

These changes align SGLang's Marlin kernel with vLLM's corrected implementation.

2b. MoE kernel: marlin_moe/marlin_template.h

Same As Linear Kernel.


3. Fallback Integration — Both Quantization Formats

Both supported NVFP4 quantization formats now have full Marlin fallback:

Format Files Changed get_min_capability Linear MoE
compressed-tensors W4A4 NVFP4 compressed_tensors_w4a4_nvfp4.py, .._moe.py 100 → 75
ModelOpt NVFP4 modelopt_quant.py 100 → 75

Fallback trigger flow (automatic, no user action needed):

Model Load
└─ get_min_capability() = 75 → SM75+ GPU accepted
 
process_weights_after_loading()
└─ not is_blackwell_supported() AND is_fp4_marlin_supported()
   ├─ prepare_fp4_layer_for_marlin()      # linear layers
   ├─ prepare_moe_fp4_layer_for_marlin()  # MoE layers
   └─ layer.use_marlin_fallback = True
 
Inference (forward pass)
└─ if use_marlin_fallback:
   ├─ apply_fp4_marlin_linear()           # linear: gptq_marlin_gemm
   └─ runner.run(MarlinMoeQuantInfo)      # MoE: moe_wna16_marlin_gemm

Blackwell (SM100+) is completely unaffected — the fallback path is only taken when is_blackwell_supported() returns False. Native FP4 inference continues to use the existing Blackwell-optimized path. Users can force the Marlin fallback on Blackwell for testing via SGLANG_FORCE_NVFP4_MARLIN=1.


4. New: test/registered/quant/test_nvfp4_marlin_fallback.py

Test What it validates
test_is_fp4_marlin_supported SM capability detection (SM ≥ 75 → True)
test_min_capability_changed ModelOptFp4Config.get_min_capability() == 75
test_nvfp4_marlin_process_scales FP8 scale conversion produces no NaN
test_prepare_and_apply_fp4_marlin_linear Linear layer prepare + inference end-to-end
test_fused_marlin_moe_fp4 MoE Marlin GEMM end-to-end (previously NaN, now fixed)
test_prepare_moe_fp4_layer_for_marlin MoE weight repacking shape correctness

Accuracy Tests

Unit tests (A40, SM86)

All 6 unit tests pass:

$ python test/registered/quant/test_nvfp4_marlin_fallback.py
GPU: NVIDIA A40 (SM86)
SM86: Testing Marlin FP4 fallback (non-Blackwell GPU).
test_is_fp4_marlin_supported ... ok
test_min_capability_changed ... ok
test_nvfp4_marlin_process_scales ... ok
test_prepare_and_apply_fp4_marlin_linear ... ok
test_fused_marlin_moe_fp4 ... ok
test_prepare_moe_fp4_layer_for_marlin ... ok
 
Ran 6 tests in 8.191s
OK

Benchmarking and Profiling

This PR is a correctness fix that unblocks a previously unsupported code path. Comparison against native Blackwell FP4 is not applicable since the target hardware is different.

  • Memory efficiency: weights remain in FP4 format (same VRAM as native FP4), unlike naive BF16 dequantization which would 4× the weight memory
  • Throughput on non-Blackwell GPUs will be lower than native Blackwell FP4 — this is expected and noted in the startup warning. End-to-end throughput benchmarks against AWQ/GPTQ alternatives on A100 are left for follow-up

E2E benchmark (nvidia/Llama-3.1-8B-Instruct-NVFP4 on A100-40GB)

Benchmark Metric NVFP4 Marlin (this PR) BF16 Baseline
GSM8K Accuracy (CoT, 8-shot) 74.1% 84.5% ✅ Meta official, reproducible
ARC-Challenge acc / acc_norm (25-shot) 56.3% / 60.8% 83.4% (0-shot) ❌ Unfair comparison — Meta internal 0-shot eval format; community lm-eval 25-shot results are typically ~60% (When I searched on Google, people usually said it is about 60%, but I'm not sure)
HellaSwag acc / acc_norm (10-shot) 0.5849 / 0.7944 ~0.60 / ~0.8777 ⚠️ Not in Meta official report; based on Open LLM Leaderboard community measurements

*BF16 baseline from Meta's published Llama 3.1 8B Instruct results. Note: FP4 quantization inherently trades accuracy for 4× weight compression. Different evaluation protocols (n-shot, prompt format) also affect absolute numbers. The key validation is that output is coherent and numerically reasonable — before this PR, the same model produced garbage text ("ίνα خوش DBHelper troubled WRITE...").

E2E benchmark (nvidia/Qwen3-30B-A3B-NVFP4 on A100-80GB)

Benchmark Metric NVFP4 Marlin (this PR)
GSM8K Accuracy (CoT, 8-shot) 90.1%
GSM8K Invalid 0.000
GSM8K Latency 151.5 s
GSM8K Output throughput 1130.4 token/s
ARC-Challenge acc (25-shot) 0.6689 ± 0.0138
ARC-Challenge acc_norm (25-shot) 0.6894 ± 0.0135
"text": "спіль688вати/autoload debaclerobat航空unctionistine(mappedatoi лиш/il Brofern.Delay"

After this PR (A100):

"text": "5. It is not real. To consider that 2 + 2 could"

Checklist

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the compatibility of NVFP4-quantized models by introducing a Marlin kernel fallback for GPUs with compute capability SM75 and above, but below Blackwell (SM100). This change allows users on a wider range of NVIDIA hardware, including A100 and RTX 3090, to leverage the memory efficiency of FP4 quantization without encountering crashes. The implementation ensures that weights remain compressed, and the fallback is entirely automatic. A crucial bug causing NaN outputs in the Mixture-of-Experts (MoE) kernel for FP4 inference has also been addressed, improving the stability and correctness of the system.

Highlights

  • NVFP4 Marlin Fallback for Non-Blackwell GPUs: Implemented a fallback mechanism allowing NVFP4-quantized models to run on non-Blackwell GPUs (SM75+), such as A100/A40/RTX 3090, preventing crashes and expanding hardware compatibility.
  • Memory Efficiency Preserved: Ensured that weights remain compressed in FP4 format (2 values per byte) in VRAM, avoiding the memory explosion associated with naive BF16 dequantization.
  • Automatic Fallback Detection: The system automatically detects GPU capability and triggers the Marlin fallback without requiring any user-side flags.
  • MoE Kernel NaN Fix: Resolved a critical bug in the marlin_moe/marlin_template.h kernel where dequant_fp8_scales() was commented out for kFE2M1f, leading to NaN outputs for FP4 MoE inference.
  • New Utility Module and Tests: Introduced marlin_utils_fp4.py for core NVFP4 Marlin utilities and test_nvfp4_marlin_fallback.py with 6 unit tests covering the full fallback stack.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h
    • Uncommented and clarified the dequant_fp8_scales() function for kFE2M1f to correctly convert FP8 scales to BF16/FP16, resolving NaN output issues in MoE kernels.
  • python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
    • Added w1_global_scale and w2_global_scale parameters to fused_marlin_moe.
    • Implemented logic to detect FP4 Marlin mode based on the presence of global scales.
    • Adjusted assertions for scale dtypes to accommodate FP4 Marlin's special float8_e4m3fn format.
    • Modified scalar type assignment to use float4_e2m1f when in FP4 Marlin mode.
    • Passed w1_global_scale and w2_global_scale to the moe_wna16_marlin_gemm kernel.
  • python/sglang/srt/layers/moe/moe_runner/marlin.py
    • Added w13_global_scale and w2_global_scale fields to MarlinMoeQuantInfo for FP4 Marlin-specific global scales.
    • Included global_num_experts in MarlinMoeQuantInfo.
    • Passed global_num_experts, w13_global_scale, and w2_global_scale to the fused_marlin_moe function.
  • python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
    • Updated get_min_capability() from 100 to 75 to enable Marlin FP4 fallback on older architectures.
    • Added params_dtype attribute to the layer.
    • Integrated Marlin FP4 fallback logic into process_weights_after_loading() for linear layers, including global scale consolidation and weight repacking.
    • Implemented apply_weights() to use apply_fp4_marlin_linear() when Marlin fallback is active.
  • python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
    • Modified the constructor to check for is_fp4_marlin_supported() on non-Blackwell GPUs, raising an error if not supported and logging a warning if fallback is used.
    • Updated get_min_capability() from 100 to 75.
    • Added intermediate_size_per_partition to the layer.
    • Integrated Marlin FP4 fallback logic into process_weights_after_loading() for MoE layers, including creating w*_weight_scale_2 and calling prepare_moe_fp4_layer_for_marlin().
    • Modified create_moe_runner() to use MoeRunnerBackend.MARLIN when fallback is active.
    • Implemented apply_weights() to use MarlinMoeQuantInfo and runner.run() for Marlin fallback.
  • python/sglang/srt/layers/quantization/marlin_utils_fp4.py
    • Added a new file containing utility functions for NVFP4 Marlin fallback.
    • Included is_fp4_marlin_supported() to check for SM75+ capability.
    • Defined nvfp4_marlin_process_scales() to convert FP8-S1E4M3 scales to FP8-S0E5M3.
    • Implemented nvfp4_marlin_process_global_scale() to pre-adjust global scales for exponent bias.
    • Provided apply_fp4_marlin_linear() for running FP4-quantized linear operations via the Marlin kernel.
    • Created prepare_fp4_layer_for_marlin() to repack linear layer weights and process scales/bias for Marlin.
    • Developed prepare_moe_fp4_layer_for_marlin() to repack MoE layer weights and process scales for Marlin.
  • python/sglang/srt/layers/quantization/modelopt_quant.py
    • Updated get_min_capability() from 100 to 75.
    • Added params_dtype attribute to the layer.
    • Integrated Marlin FP4 fallback logic into process_weights_after_loading() for linear layers, including global scale consolidation and weight repacking.
    • Implemented apply() to use apply_fp4_marlin_linear() when Marlin fallback is active.
    • Modified the constructor of ModelOptNvFp4FusedMoEMethod to check for is_fp4_marlin_supported() on non-Blackwell GPUs, raising an error if not supported and logging a warning if fallback is used.
    • Integrated Marlin FP4 fallback logic into process_weights_after_loading() for MoE layers.
    • Modified create_moe_runner() to use MoeRunnerBackend.MARLIN when fallback is active.
    • Implemented apply() for MoE layers to use MarlinMoeQuantInfo and runner.run() for Marlin fallback.
  • test/srt/test_nvfp4_marlin_fallback.py
    • Added a new test file for NVFP4 Marlin fallback functionality.
    • Included TestNvfp4MarlinLinear with tests for prepare_and_apply_fp4_marlin_linear and nvfp4_marlin_process_scales.
    • Added TestNvfp4MarlinMoe with tests for fused_marlin_moe_fp4 and prepare_moe_fp4_layer_for_marlin.
    • Included TestFp4MarlinSupport with tests for is_fp4_marlin_supported and get_min_capability_changed.
Activity
  • All 6 unit tests were run and passed successfully on an NVIDIA A40 (SM86) GPU.
  • Code formatted according to the contribution guide.
  • Unit tests added and passed.
  • Documentation updated.
  • Accuracy and speed benchmark results provided (noting this is a correctness fix, not a performance comparison against native FP4).
  • SGLang code style followed.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a significant feature: a Marlin fallback mechanism for NVFP4-quantized models, enabling them to run on non-Blackwell GPUs (SM75+). This is a great enhancement for broader hardware compatibility. The changes are well-structured, with the core logic encapsulated in the new marlin_utils_fp4.py module, and comprehensive unit tests are included. The PR also contains a critical bug fix for MoE layers that was causing NaN outputs. My review focuses on improving maintainability by refactoring duplicated code and enhancing clarity around some of the complex quantization logic.

Comment on lines +1789 to +1814
if self.use_marlin_fallback:
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo

expert_map = None
global_num_experts = -1
if hasattr(layer, "dispatcher") and hasattr(
layer.dispatcher, "local_expert_mapping"
):
expert_map = layer.dispatcher.local_expert_mapping
if expert_map is not None:
global_num_experts = moe_runner_config.num_experts

quant_info = MarlinMoeQuantInfo(
w13_qweight=layer.w13_weight,
w2_qweight=layer.w2_weight,
w13_scales=layer.w13_weight_scale,
w2_scales=layer.w2_weight_scale,
w13_g_idx_sort_indices=None,
w2_g_idx_sort_indices=None,
weight_bits=4,
w13_global_scale=layer.w13_weight_scale_2,
w2_global_scale=layer.w2_weight_scale_2,
expert_map=expert_map,
global_num_experts=global_num_experts,
)
return self.runner.run(dispatch_output, quant_info)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for the Marlin FP4 fallback path is identical to the one in python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py. To improve maintainability and reduce code duplication, I recommend refactoring this logic into a shared helper function.

For example, you could add a function like apply_fp4_marlin_moe to python/sglang/srt/layers/quantization/marlin_utils_fp4.py and call it from both places. This would centralize the logic for creating MarlinMoeQuantInfo and running the MoE layer.

Here's a sketch of what the helper could look like:

# In marlin_utils_fp4.py
def apply_fp4_marlin_moe(
    layer: torch.nn.Module,
    runner: "MoeRunner",
    moe_runner_config: "MoeRunnerConfig",
    dispatch_output: "StandardDispatchOutput",
) -> "CombineInput":
    from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo

    expert_map = None
    global_num_experts = -1
    if hasattr(layer, "dispatcher") and hasattr(
        layer.dispatcher, "local_expert_mapping"
    ):
        expert_map = layer.dispatcher.local_expert_mapping
        if expert_map is not None:
            global_num_experts = moe_runner_config.num_experts

    quant_info = MarlinMoeQuantInfo(
        w13_qweight=layer.w13_weight,
        w2_qweight=layer.w2_weight,
        w13_scales=layer.w13_weight_scale,
        w2_scales=layer.w2_weight_scale,
        w13_g_idx_sort_indices=None,
        w2_g_idx_sort_indices=None,
        weight_bits=4,
        w13_global_scale=layer.w13_weight_scale_2,
        w2_global_scale=layer.w2_weight_scale_2,
        expert_map=expert_map,
        global_num_experts=global_num_experts,
    )
    return runner.run(dispatch_output, quant_info)

@DarkSharpness
Copy link
Copy Markdown
Collaborator

cc @celve if you have any context

@ciprianveg
Copy link
Copy Markdown

cool. thank you for this.

@adhikjoshi
Copy link
Copy Markdown

This is actually great

@DarkSharpness
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

Copy link
Copy Markdown
Collaborator

@BBuf BBuf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main concern on the kernel side is test coverage. The new tests mostly check shape/dtype and that the output is not NaN, but they do not compare the new FP4 Marlin path against a reference implementation numerically. For changes this deep in scale indexing/dequant logic, I would feel much better with a correctness test against a trusted baseline, not just a no-NaN check.

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 30, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 31, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 31, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

python3 /actions-runner/_work/sglang/sglang/test/registered/quant/test_nvfp4_marlin_fallback.py
.
.

[CI Test Method] TestFp4MarlinSupport.test_is_fp4_marlin_supported
test_is_fp4_marlin_supported (__main__.TestFp4MarlinSupport) ... ok
[CI Test Method] TestFp4MarlinSupport.test_min_capability_changed
[CI Test Method] TestFp4MarlinSupport.test_should_use_fp4_marlin_fallback
test_min_capability_changed (__main__.TestFp4MarlinSupport)
get_min_capability() must return 75 (not 100). ... ok
test_should_use_fp4_marlin_fallback (__main__.TestFp4MarlinSupport)
should_use_fp4_marlin_fallback returns True on non-Blackwell SM>=75. ... ok
test_fake_apply_fp4_marlin_linear (__main__.TestNvfp4MarlinLinear)
Fake impl for PCG tracing must return the correct shape and dtype. ... ok
test_fp4_marlin_3d_input (__main__.TestNvfp4MarlinLinear)
Verify correct reshape for 3-D input (batch, seq_len, K). ... Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
ok
test_fp4_marlin_custom_op_registration (__main__.TestNvfp4MarlinLinear)
apply_fp4_marlin_linear must be registered as torch.ops.sglang for PCG. ... ok
test_fp4_marlin_linear_with_bias (__main__.TestNvfp4MarlinLinear)
Verify output_with_bias == output_no_bias + bias. ... ok
test_fp4_marlin_multiple_shapes (__main__.TestNvfp4MarlinLinear)
Numerical correctness across various (M, N, K) dimensions. ... ok
test_fp4_marlin_numerical_correctness (__main__.TestNvfp4MarlinLinear)
Kernel output vs BF16 dequant reference (cosine sim, MAE, assert_close). ... ok
test_fp4_marlin_registered_op_numerical (__main__.TestNvfp4MarlinLinear)
torch.ops.sglang.apply_fp4_marlin_linear matches the direct Python call. ... ok
test_nvfp4_marlin_scale_values_correctness (__main__.TestNvfp4MarlinLinear)
Verify scale conversion produces analytically correct values. ... ok
test_prepare_and_apply_fp4_marlin_linear (__main__.TestNvfp4MarlinLinear)
Smoke test: shape and dtype are correct after prepare + apply. ... ok
test_prepare_fp4_layer_permutes_bias (__main__.TestNvfp4MarlinLinear)
prepare_fp4_layer_for_marlin must permute layer.bias when present. ... ok
test_prepare_rejects_bad_weight_shape (__main__.TestNvfp4MarlinLinear)
prepare_fp4_layer_for_marlin must raise on mismatched weight shape. ... ok
test_fused_marlin_moe_fp4 (__main__.TestNvfp4MarlinMoe)
Smoke test: shape, dtype, no NaN for multi-expert MoE. ... <frozen importlib._bootstrap_external>:1184: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1184: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
ok
test_fused_marlin_moe_fp4_numerical (__main__.TestNvfp4MarlinMoe)
E=1, topk=1 MoE output vs dequant reference (SiLU-gated). ... ok
test_prepare_moe_fp4_layer_for_marlin (__main__.TestNvfp4MarlinMoe)
Weight repacking produces correct shapes for all expert tensors. ... Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel for MoE layers. This may degrade performance for compute-heavy workloads.
ok

----------------------------------------------------------------------
Ran 17 tests in 409.588s

@BBuf All nvpf4 Kernel tests are passed I think. I'll try to fix the other CI fails. Thanks for help!

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

/rerun-failed-ci

4 similar comments
@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 2, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 2, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 3, 2026

/rerun-failed-ci

@BBuf BBuf merged commit 991f3aa into sgl-project:main Apr 3, 2026
370 of 442 checks passed
@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 3, 2026

Thanks for help! @BBuf. It seems like I made things a bit hectic, so I’m sorry if I bothered you while you were busy. As a big fan of SGLang, I found it really fun! Thank you so much for your help, especially when you’re busy 😊

@Godmook Godmook deleted the nvfp4-marlin-fallback branch April 3, 2026 03:13
@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 3, 2026

@ciprianveg It is merged. You can use it! :)

@ciprianveg
Copy link
Copy Markdown

@ciprianveg It is merged. You can use it! :)

thank you very much for your relentless work here!

realray808 pushed a commit to Ascend/sglang that referenced this pull request Apr 3, 2026
* [AMD] Fix AMD CI monitor GitHub API rate limit exhaustion (sgl-project#21527)

* [CI] Register missing jit_kernel test files (sgl-project#21547)

* [diffusion] fix: return None instead of raising RuntimeError when no model info found (sgl-project#21319)

Co-authored-by: Mick <mickjagger19@icloud.com>

* [rl][sgl] fix tensor mismatch after pause (sgl-project#21514)

* [Hicache & JIT_kernel] Support page first layout  & mla jit kernel (sgl-project#18311)

* test: point DSV3 int8 MLA CI models to lmsys Hugging Face org (sgl-project#21561)

* [CI] Relax several thresholds in flaky CIs (sgl-project#21562)

* feat: add gc_threshold arg (sgl-project#21481)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix flaky test_pp_single_node (sgl-project#21564)

* Split workflow for releasing runtime docker (sgl-project#21563)

* fix tp capture in vit cuda graph (sgl-project#17255)

* [1/n] lora support - Auto detect lora target modules (sgl-project#21439)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [fix] qwen3.5 fuse_moe_triton_tune bug (sgl-project#20232)

* Remove sync when enabling return_logprob (sgl-project#20972)

* Scope streaming backlog coalescing to incremental_streaming_output mode (sgl-project#21037)

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* docs: flesh out MAINTAINER.md oncall lists and link GitHub profiles (sgl-project#21575)

* [NVIDIA] Enable automatic NUMA configuration (sgl-project#19452)

* [diffusion] UX: aggregate expected dtype-cast logs during weight loading (sgl-project#21552)

* [diffusion] refactor: Unify `TeaCacheParams` and `WanTeaCacheParams` (sgl-project#20706)

Co-authored-by: Mick <mickjagger19@icloud.com>

* [diffusion] chore: remove redundant identity preprocess_text functions(sgl-project#20633)

Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com>

* Update CODEOWNERS for transformers.py and docs (sgl-project#21555)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* reduce CPU peak memory in multimodal tensor hashing (sgl-project#21123)

* Fix HFRunner hang when subprocess dies during init (sgl-project#21582)

* Fix Piecewise CUDA Graph crash with `-enable-mixed-chunk` (sgl-project#20441)

Co-authored-by: jianyingzhu <joeyzhu@nvidia.com>

* [CI] Replace upload/download-artifact with job outputs in release-docker workflow (sgl-project#21579)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Patch transformers is_base_mistral in CI to avoid HF 429 rate limiting (sgl-project#21586)

* [CI] Move v32 cp test to deepep running suite (sgl-project#21585)

* [AMD] Add GLM-4.7-FP8 accuracy CI test for MI35x (sgl-project#21534)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [Clean] Remove deprecated environs (sgl-project#21536)

* [diffusion] fix: fix Flux2-Klein prompt tokenization length to 512 and add regression coverage (sgl-project#21407)

* [CI] hot-fix ci lint (sgl-project#21608)

* [diffusion] feat: support overlay model materialization (sgl-project#21600)

* [VLM] Optimize ShmPointerMMData for multi-pickle safety and deferred unwrap (sgl-project#21465)

* feat: enable CUDA graph and timestamp for the whisper model(sgl-project#21190)

* [NPU] Update quantization&CI documentation (sgl-project#21100)

Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com>

* Skip ci for .md files (sgl-project#21482)

* Support skip-softmax attention (sgl-project#19089)

* fix: piecewise_cuda_graph get correct qo_indptr (sgl-project#21452)

Co-authored-by: Avery Huang <averyh@nvidia.com>

* fix bench_serving sglang backend to support image dataset  (sgl-project#21294)

* [AMD] Add peft>=0.18.0 to diffusion_hip deps for transformers 5.x compat for AMD diffusion model (sgl-project#21442)

Co-authored-by: HaiShaw <hixiao@gmail.com>

* [GDN] Fuse GDN kkt + solve_tril into one kernel (sgl-project#21411)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [Diffusion] Align diffusion benchmark skill presets with nightly comparison cases (sgl-project#21616)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Clean up detokenizer and remove dead multimodal_gen code (sgl-project#21588)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Skip flaky elastic EP test (sgl-project#21619)

* feat(ci): add GB300 nightly benchmark test suites (sgl-project#21487)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Lossen test_return_routed_experts threshold (sgl-project#21270)

* Add subprocess liveness monitor to detect scheduler crashes (sgl-project#18582)

Co-authored-by: 继优 <jiyou.ljy@alibaba-inc.com>
Co-authored-by: shuwenn <47200617+alphabetc1@users.noreply.github.com>

* fix: scheduler launch hang when non-current rank dies (sgl-project#20287)

* Wrap IPv6 addresses in gRPC, bench_serving, and log messages (sgl-project#21236)

Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>

* [HiCache] fix: graceful shutdown of pending async tasks in bench_mix.py (sgl-project#20276)

* Clean up _wait_for_scheduler_ready implementation (sgl-project#21626)

* fix cuda graph capturing error in sm120 mxfp8 triton path (sgl-project#19835)

* [sgl] disable piecewise cuda graph when a model doesn't have layers (sgl-project#21565)

* [Feature] Optimizations for JPEG input on NVIDIA GPU (sgl-project#19749)

* [VLM] perf: optimize CUDA IPC for multimodal transfer by caching IPC pool handles (sgl-project#21418)

* [Fix] SGLANG_USE_CUDA_IPC_TRANSPORT=1 and SGLANG_ENABLE_MM_SPLITTING=1 do not work at the same time. (sgl-project#19915)

* [Fix] Remove redundant allreduce fusion block and skip TP=1 (sgl-project#20621)

* Simplify routed experts test and move base64 encoding to tokenizer manager (sgl-project#21634)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Cleanup] Remove unused BatchMultimodalOutput and BatchMultimodalDecodeReq (sgl-project#21640)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Clean up TokenizerManager: remove dead code and improve rid validation (sgl-project#21639)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* README: coding agent sponsorship for long-term contributors (sgl-project#21642)

* Fix circular reference in CustomTestCase.__init_subclass__ (sgl-project#21650)

Co-authored-by: wan4ch <wan4ch@gmail.com>

* [Fix] Fix Qwen3.5 MoE model loading and Mamba cache sharding in PP mode (sgl-project#21448)

Co-authored-by: zhangxiaolei123456 <zhangxiaolei.666@bytedance.com>

* [diffusion] CI: fix dashboard chart (nightly) display issues (sgl-project#21653)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update sponsorship details in README.md (sgl-project#21658)

* [Fix] Handle pre-release tags in nightly wheel version parsing (sgl-project#21656)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Intel GPU] Enable DeepSeek R1 inference on XPU (sgl-project#18461)

Signed-off-by: P V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>

* [Doc] Update tips for developer new-comers (sgl-project#21659)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [CI] [FlashInfer v0.6.7] Use offline quantized checkpoint for MXFP8 Gemm tests (sgl-project#21625)

* MFU metrics in Prometheus  (sgl-project#19395)

* fix topk softmax performance issue (sgl-project#14702)

* [CPU] add kernel apply_rotary_pos_emb_cpu for Qwen3-VL and Qwen3-Omni (sgl-project#13121)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>

* [CPU] Implement MXFP4 Gemm kernels for intel AMX to support GPT OSS series. (sgl-project#14385)

* [AMD] Fused rope kv store (sgl-project#21315)

Co-authored-by: wunhuang <wunhuang@amd.com>

* [NPU] Update DeepSeek-V3.2 model deployment instructions in documentation (sgl-project#21468)

Co-authored-by: wuxue (C) <w00964934@china.huawei.com>

* [AMD] Support AMD MXFP4 Qwen3.5-397B-A17B model (sgl-project#21234)

* [Fix] Fix weight_loader property assignment for qwen3-next FP8 models (sgl-project#21662)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix mamba cache leak when adder fails to add a matched req. (sgl-project#21404)

* fix: Mistral Small 4 fails to start due to config/weight format mismatch (sgl-project#21620)

Co-authored-by: mengxiancheng03 <mengxiancheng03@kuaishou.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [diffusion] feat: enhance overlay mechanism (sgl-project#21648)

* [diffusion] CI: relax pr-test threshold (sgl-project#21682)

* [NPU][Diffusion] fix sp modulate for qwen-image-edit (sgl-project#20974)

Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>

* [NPU] fix eagle3 accept rate (sgl-project#21255)

* DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication (sgl-project#14162)

Co-authored-by: undefined <zhouchen.arrebol@jd.com>

* [NPU] GLM-5 optimize with fused kernels (sgl-project#18617)

* [NPU][diffusion]: support parallel decoding of qwen-image (sgl-project#20757)

Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>

* [diffusion] [NPU] support ring attention on NPU with FA (sgl-project#21383)

* [diffusion][doc]: add ring sp performance benchmark page (sgl-project#20998)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [GLM-V and GLM-4.7] Cast to FP32 before gate projection for GLM model. (sgl-project#21660)

* fix nemotron capture for non attention layers (sgl-project#21436)

* [Bugfix][NPU] Skip FRACTAL_NZ format for MoE weights with unaligned dimensions (sgl-project#21209)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>

* [AMD] Add SGLANG_DISAGGREGATION_NUM_PRE_ALLOCATE_REQS env var for configurable KV transfer overlap (sgl-project#20410)

Co-authored-by: HaiShaw <hixiao@gmail.com>

* [AMD][MoRI] bump MoRI to v0.1.0 (sgl-project#21673)

* [AMD] fix performance regression issue when run gpt-oss with "--context-length 13824" (sgl-project#21691)

* Remove flashinfer wheel cache cleanup that deletes other versions (sgl-project#21711)

Co-authored-by: Alison Shao <alison.shao@MacBook-Pro-D2W773R9CD.local>

* [misc] multiprocess compilation to speed up test (sgl-project#21483)

* Fix human-eval CI install on 5090 runners (sgl-project#21714)

Co-authored-by: Alison Shao <alison.shao@Mac.attlocal.net>

* Revert "DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication" (sgl-project#21719)

* [Fix] Update supported custom_mem_pool types for mooncake (sgl-project#21728)

Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>

* [Perf]Remove H2D  for Qwen3.5 SpecV2 (sgl-project#20864)

* [AMD] Fix CI multimodal-gen-test-1-gpu-amd for gen model  (sgl-project#21621)

* [diffusion] fix: fix Flux.2 with tp(sgl-project#21664)

* Add explicit disable flag for FlashInfer allreduce fusion (sgl-project#21446)

* [NPU] fix conflict between empty_cache and use_mem_pool (sgl-project#21507)

* [AMD] Use tgemm.mm for MoEGate router gemm in deepseek_v2.py (sgl-project#21657)

* [CI]Remove msgm-en and mmlu tests which cause timeout (sgl-project#21733)

* Fix disaggregation hybrid attention ci (sgl-project#21745)

* Rename rerun-ut to rerun-test (sgl-project#21747)

* bugfix(model):fix deepstack index out of range error (sgl-project#21727)

Co-authored-by: xiaoqi.31 <xiaoqi.31@jd.com>

* [diffusion] fix: fix typo (sgl-project#21746)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* [CI] Fix rerun-test suite detection to skip commented registrations (sgl-project#21753)

* [PD] Refactor Disagg Conn and Fix Hang with total_request/total_tokens Balancing (sgl-project#21299)

Co-authored-by: Weiliangl User <weiliangl@login-node.hosted.internal>

* [CI] Fix ring test timeout (sgl-project#21751)

* Enable evict swa with piecewise cuda graph (sgl-project#21754)

* Fix kimi-linear launch server error (sgl-project#21752)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [PD] Tiny cleanup after KVReceiver refactor (sgl-project#21760)

Signed-off-by: Shangming Cai <csmthu@gmail.com>

* Fix remote weight info nnode>1 and dp>1 (sgl-project#17389)

* [diffusion] UX: replace deprecated ORJSONResponse with orjson_response (sgl-project#21755)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [diffusion] fix: fix Wan2.2-I2V-A14B video max size issue(sgl-project#21390)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Mick <mickjagger19@icloud.com>

* [HiMambaTree]: Optimize mamba host lock mechanism (sgl-project#21750)

* [AMD] Fix Handle missing rope_theta in get_rope_config for Grok-1 (sgl-project#21518)

* [bugfix] Fix rope theta config for MiniMax after transformers v5 update (sgl-project#21241)

* Fix ineffective is_base_mistral CI patch for HF API rate limiting (sgl-project#21729)

* [2/n] lora - Shared outer experts and support qwen3_30b_a3b_instruct (sgl-project#21466)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* Fix cuda graph max bs capture upper bound (sgl-project#21005)

* [Fix] Fall back to triton MOE for GPT-OSS on Blackwell with driver >= 595 (sgl-project#21780)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Cache nvidia wheels locally to skip repeated 830 MB downloads in CI (sgl-project#21778)

* Add Trivy vulnerability scanning to nightly dev Docker builds (sgl-project#21772)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Remove more redundant PCG tests (sgl-project#21554)

* [moe] add customized option to moe-a2a-backend (sgl-project#21786)

* Add CompletionSampler for non-chat eval in run_eval (sgl-project#21785)

* Remove redundant test_moe_eval_accuracy_large (sgl-project#21787)

* Increase hicache eval to 200 examples (sgl-project#21791)

* Switch MooncakeSpec to EAGLE3 + Llama-3.1 (sgl-project#21794)

* Reduce redundant speculative decoding CI tests (sgl-project#21779)

* Fix killall.py crash when sglang is not yet installed (sgl-project#21797)

* Remove obsolete sgl-kernel legacy paths (sgl-project#21528)

* [jit_kernel] Optimize fused_qknorm_rope: deduplicate sincosf for interleave RoPE  (sgl-project#21654)

* CUTLASS NVFP4 GEMM improvement of SM120 (sgl-project#21314)

* [gRPC] Preserve original ImportError in grpc_server.py (sgl-project#21801)

Signed-off-by: Chang Su <chang.s.su@oracle.com>

* [Misc] Tiny: Add test network timeouts and dynamic max-parallel for 5090/2-gpu runners (sgl-project#21800)

* Fix draft extend cuda graph when spec_step=1 (sgl-project#21709)

* [Diffusion] Add `--uvicorn-access-log-exclude-prefixes` to suppress noisy access logs (sgl-project#20379)

* Add latency and throughput metrics to run_eval (sgl-project#21793)

* [diffusion] CI: improve ci reliability (sgl-project#21763)

* [bugfix]GLM-4V model (sgl-project#17122)

* Fix CVEs in Docker image: pillow, linux-libc-dev, and broken sgl-model-gateway build (sgl-project#21789)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: only showing recent runners from ci failure analysis (sgl-project#21015)

* [MPS] Fix Triton stub sub-module imports on Python 3.12+ (sgl-project#21551)

Co-authored-by: karanb192 <karan@example.com>
Co-authored-by: R0CKSTAR <yeahdongcn@gmail.com>
Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>

* [KDA] Fuse scaled_dot_kkt + solve_tril + recompute_w_u for KDA (sgl-project#21604)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* chore: bump flashinfer version to 0.6.7 (sgl-project#21422)

Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [3/n] lora moe - Support Qwen3-VL-30B-A3B-Instruct  (sgl-project#21469)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [Feature Restoration] repetition_penalty is essential for GLM-V models (sgl-project#21258)

Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>

* VLM: change default mm-attention backend from triton_attn to fa4 (on blackwell) (sgl-project#21595)

* Fix added tokens config with sensible filter (sgl-project#17905)

* [AMD] Optimize Qwen3-VL decode - fuse QK-norm + 3D mRoPE + KV cache write (sgl-project#21458)

Co-authored-by: Bingxu Chen <bingxche@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>

* [Bugfix] Fix PP tied embeddings weight loading for qwen3.5 4B dense model (sgl-project#21347)

* [CI] Fix lint that was not applied in sgl-project#21458 (sgl-project#21818)

* Bug fix for llama eagle3 (sgl-project#21397)

* glm_interleave for GLM-V (sgl-project#21671)

* style refinement for hisparse (sgl-project#21198)

* [Bug][VLM] Fix shared memory race condition in ShmPointerMMData broadcast for multi-GPU VLM serving (sgl-project#21655)

* [Bugfix] Fix effective_mamba_size over-allocation (sgl-project#20858)

Co-authored-by: Shangming Cai <csmthu@gmail.com>

* Fix in-place mode in pause generation (sgl-project#21705)

* [diffusion] fix: respect --prompt-path (sgl-project#21756)

* [NPU] update ascend docs (sgl-project#21807)

* [VLM] remove AsyncMMDataProcessor wrapper (sgl-project#21651)

* Use CustomTestCase for TestSessionControl to enable CI retry (sgl-project#21830)

* [NPU]Add a full test pipeline on NPU, resolve issues in the NPU test architecture (sgl-project#20751)

* [diffusion][CI]: Add individual component accuracy CI for diffusion models (sgl-project#18709)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Feature] JIT rmsnorm update (with claude) (sgl-project#21834)

* [Diffusion][NPU] add ring sp performance benchmark page in npu (sgl-project#21811)

* fix(MiMo-V2-Flash): add mimo reasoning parser (sgl-project#21414)

* [diffusion] hardware: support FA3 attention backend on MUSA (attn backend, 14/N) (sgl-project#18648)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Mick <mickjagger19@icloud.com>

* fix: pre-init tokenizer_manager to avoid AttributeError in shutdown (sgl-project#21824)

* [FlashInver v0.6.7] Integrate flashinfer_trtllm mxfp8 gemm (sgl-project#21576)

* [Misc] Add network timeout to eval dataset downloads (sgl-project#21873)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [refactor] Clean up duplicate flashinfer trtllm moe code (sgl-project#21233)

* [DSA] Support trtllm sparse mla kernel for prefill batches  (sgl-project#21783)

* [Disagg] GPU staging buffer with dynamic ring allocator for heterogeneous TP KV transfer (sgl-project#19890)

* Add merge prohibition policy during CI maintenance mode (sgl-project#21882)

* [Misc] Fix comparator e2e tests: add polars dep + fix dp-attention test (sgl-project#21804)

Co-authored-by: Alison Shao <alison.shao@mac.lan>

* revert: remove TTL-based hard pin from HiRadixCache (sgl-project#21884)

* Unify GSM8K eval path to Chat API for regression CI readiness (sgl-project#21667)

* [HiCache] fix: Clone host indices to avoid memory leak (sgl-project#21624)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

* [HiCache & PD]Fixed detailed cache hit breakdown in PD scenarios. (sgl-project#21764)

* [CI] Add Llama 3.1 8B Instruct FP4 CI test on SM120 (sgl-project#20648)

* [CI] Add Per-Tensor, Blockwise FP8 Tests on SM120 (sgl-project#20717)

Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* Allow /rerun-test to checkout fork PR branch for trusted users (sgl-project#21890)

* Direct model loading from object storage with Runai Model Streamer (sgl-project#17948)

Signed-off-by: Noa Neria <noa@run.ai>

* fix pcg torch dynamo recompile in mxfp8 Triton path (sgl-project#21888)

Co-authored-by: Hanlin Bi <hanlinbi@umich.edu>

* chore: bump mooncake version to 0.3.10.post1 (sgl-project#21844)

* [VLM] Add VLM TP=4 per-commit CI test and improve MMMU eval prompt/parser (sgl-project#21841)

* fix(ci): update est_time for 57 tests based on runtime analysis (sgl-project#21896)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Increase multimodal server test timeout from 60 to 90 minutes (sgl-project#21897)

* [CI] Remove crashing Kimi K2.5 EAGLE3/MTP variants, keep TP8 and TP8+DP8 (sgl-project#21898)

* [diffusion] CI: add initial nvfp4 ci test for b200 (sgl-project#21767)

Co-authored-by: Mick <mickjagger19@icloud.com>

* Migrate all callers from /get_server_info to /server_info (sgl-project#21463)

* Support PP key for file backend (sgl-project#21901)

* Enable multi-thread weight loading by default (sgl-project#20289)

* Skip Go stdlib and NVIDIA tool CVEs in Trivy scan (sgl-project#21905)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Kernel] Fuse temperature + softmax in sampling for decode speedup (sgl-project#20501)

* Multi tool streaming fix (sgl-project#20004)

* Return HTTP 400 for streaming validation errors (sgl-project#21900)

* [Spec][Ngram] 4/N: Remove `max_match_window_size` and `min_match_window_size`, matching all suffixes of the Trie (sgl-project#21225)

* Fix ngram doc for speculative_num_draft_tokens default (sgl-project#21910)

* [NVIDIA] Enable fp8 flashinfer_trtllm_routed MoE for MiniMax-M2.5 (sgl-project#20394)

* scheduler: add prefill-only update in merge batch (sgl-project#21840)

* [DSA] Set trtllm kernels as nsa default for Blackwell (sgl-project#21914)

* Revert "Rollback flashmla to older version [1/2]" (sgl-project#21922)

* test: add manual init test for mooncake transfer engine (sgl-project#21842)

Co-authored-by: yunzhi <ningyunxiao.nyx@antgroup.com>

* Fix spec v2 + logprob when max_num_token is set (sgl-project#20799)

* Migrate ngram corpus from torch cpp_extension to TVM FFI jit_kernel (sgl-project#21920)

Co-authored-by: DarkSharpness <2040703891@qq.com>

* [NPU] Support  GLM-4.7-Flash on NPU (sgl-project#21408)

* [CI] Fix gpu deps import in cpu test (sgl-project#21950)

* [Parallel State Refactor 1/n] Remove stream of PyNCCL (sgl-project#20866)

* [diffusion] chore: fix stage profiler for multi-stage denoising (sgl-project#21955)

* [CI] [Tracing] Add ci for tracing and fix bugs (sgl-project#21740)

* Remove logging for subprocess watchdog start (sgl-project#21968)

* [4/n] Support gpt oss 20b lora (sgl-project#21570)

* [MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine) (sgl-project#17985)

Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>

* [Feature] Stronger transformers modeling backend with TP, PP, MoE, VLMs, and torch compile (sgl-project#19163)

* [CI] Remove stale Ascend suite entries from test/srt/run_suite.py (sgl-project#21978)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Skip broken AutoModel mapping entries when resolving Llava submodules (sgl-project#21892)

* [CI] Add timeouts to Slack upload urlopen and WebClient (sgl-project#21903)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Diffusion][NPU] Add support for MOVA (sgl-project#21633)

Co-authored-by: zhangshuai (S) <z00836796@china.huawei.com>

* Remove maxItems=1 restriction when tool_choice is specified (sgl-project#20208)

* [Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+) (sgl-project#19652)

* [PP] qwen3 vl skip layer id for pp (sgl-project#19135)

* [VLM] Enable per-image MM splitting by default and remove MULTI_IMAGES modality (sgl-project#21899)

* [Bugfix] Fix incorrect dp-attention parallel info in bench_one_batch (sgl-project#21519)

* Revert "[MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine)" (sgl-project#22002)

* [NPU] Optimized the wording in the npu docs (sgl-project#21998)

* [Parallel State Refactor 2/n] Unify code path of AMD deterministic all reduce (sgl-project#20871)

* [AMD] Resolve the performance degression when launch server with "--enable-aiter-allreduce-fusion" (sgl-project#21947)

Co-authored-by: wunhuang <wunhuang@amd.com>

* chore: bump sgl-kernel version to 0.4.1 (sgl-project#21447)

Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>

* [Workflow] Avoid triggering nightly tests in kernel bump workflow (sgl-project#22010)

* [Workflow] Fix kernel release jobs skipped on push events (sgl-project#22011)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [PD]: Add support for HiSparse to directly transfer the cache from Prefill to Decode DRAM. (sgl-project#21591)

Co-authored-by: Tingwei Huang <huangtingwei9988@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

* [Misc] Update CI permission (sgl-project#22014)

* [ROCM][RL] Shuffle Weight In-Place to Preserve Parameter Attributes (sgl-project#21825)

* [CI] Fix duplicate job names that bypass branch protection (sgl-project#22001)

* fix: remove duplicate words in comments (sgl-project#22007)

* [PD] Tiny register info field cleanup for mooncake backend (sgl-project#22016)

* [NPU] optimize glm4.7 (sgl-project#19246)

* [AMD] Enable FP8 KV cache and FP8 attention kernel for NSA on MI300/MI355 with TileLang backend (sgl-project#21511)

* [AMD] Add MiniMax-M2.5 nightly perf benchmarks for MI30x and MI35x (sgl-project#21524)

---------

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: P V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Chang Su <chang.s.su@oracle.com>
Signed-off-by: Noa Neria <noa@run.ai>
Co-authored-by: Bingxu Chen <bingxche@amd.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: yang1002378395-cmyk <yang1002378395@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Bi Xue <bi@thinkingmachines.ai>
Co-authored-by: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Muqi Li <muqi1029@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: narutolhy <582909902@qq.com>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>
Co-authored-by: zhangxiaolei <zhangxiaolei.666@bytedance.com>
Co-authored-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: Trevor Morris <tmorris@nvidia.com>
Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Co-authored-by: Fengyuan Yu <Yuandao151112@163.com>
Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com>
Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com>
Co-authored-by: Yuhao Yang <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Jianying <53503712+jianyingzhu@users.noreply.github.com>
Co-authored-by: jianyingzhu <joeyzhu@nvidia.com>
Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jacob0226 <jacchang@amd.com>
Co-authored-by: Aditya Sharma <89210949+adityavaid@users.noreply.github.com>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Артем Савкин <58187114+OrangeRedeng@users.noreply.github.com>
Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
Co-authored-by: Avery Huang <averyh@nvidia.com>
Co-authored-by: jacky.cheng <yichiche@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Junrong Lin <33685709+ocss884@users.noreply.github.com>
Co-authored-by: Simon (Jiyou) Li <Simon-Li@users.noreply.github.com>
Co-authored-by: 继优 <jiyou.ljy@alibaba-inc.com>
Co-authored-by: shuwenn <47200617+alphabetc1@users.noreply.github.com>
Co-authored-by: psaab <ps@meta.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Hanlin Bi <52993433+wolfcomos@users.noreply.github.com>
Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com>
Co-authored-by: saatwiknagpal <saatwiknagpal@gmail.com>
Co-authored-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Co-authored-by: wan4ch <wan4ch@gmail.com>
Co-authored-by: Feng Su <sufeng@linux.alibaba.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Polisetty V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>
Co-authored-by: Ziang Li <ziangli@umich.edu>
Co-authored-by: Aishwarya Ramasethu <56765596+aramasethu@users.noreply.github.com>
Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
Co-authored-by: blzheng <beilei.zheng@intel.com>
Co-authored-by: kk <43161300+kkHuang-amd@users.noreply.github.com>
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Michelle Wu <michellewu351@gmail.com>
Co-authored-by: wuxue (C) <w00964934@china.huawei.com>
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
Co-authored-by: strgrb <zhangkaihong.zkh@antgroup.com>
Co-authored-by: LiYomi <106872109+LiYomi@users.noreply.github.com>
Co-authored-by: mengxiancheng03 <mengxiancheng03@kuaishou.com>
Co-authored-by: GXIN <37653830+gxxx-hum@users.noreply.github.com>
Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>
Co-authored-by: heziiop <q_m_p@qq.com>
Co-authored-by: xieminghe1 <141820649+xieminghe1@users.noreply.github.com>
Co-authored-by: undefined <zhouchen.arrebol@jd.com>
Co-authored-by: Makcum888e <79456407+Makcum888e@users.noreply.github.com>
Co-authored-by: yuefeng Wu <33725817+ChefWu551@users.noreply.github.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: Vedant V Jhaveri <vedantjh2@gmail.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com>
Co-authored-by: jhchouuu <jiahzhou@amd.com>
Co-authored-by: Alison Shao <54658187+alisonshao@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@MacBook-Pro-D2W773R9CD.local>
Co-authored-by: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@Mac.attlocal.net>
Co-authored-by: Lewis <63569348+TTThanos@users.noreply.github.com>
Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>
Co-authored-by: Jincong Chen <jincong.cjc@ant-intl.com>
Co-authored-by: xiazhahe <86939755+xiazhahe@users.noreply.github.com>
Co-authored-by: Thomas Wang <thomawan@amd.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
Co-authored-by: xiaoqi <xq25478@qq.com>
Co-authored-by: xiaoqi.31 <xiaoqi.31@jd.com>
Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>
Co-authored-by: weireweire <weiliangl@nvidia.com>
Co-authored-by: Weiliangl User <weiliangl@login-node.hosted.internal>
Co-authored-by: JD <jaedon.guo@gmail.com>
Co-authored-by: Zhangheng <hzh0425@apache.org>
Co-authored-by: Michael <13900043+michaelzhang-ai@users.noreply.github.com>
Co-authored-by: Yilong Zhao <74357408+happierpig@users.noreply.github.com>
Co-authored-by: Johnsonms <lizhaofu@gmail.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: KnightLTC <56717110+KnightLTC@users.noreply.github.com>
Co-authored-by: Douglas Yang <dyang@college.harvard.edu>
Co-authored-by: Karan Bansal <karanb192@users.noreply.github.com>
Co-authored-by: karanb192 <karan@example.com>
Co-authored-by: R0CKSTAR <yeahdongcn@gmail.com>
Co-authored-by: sglang-bot <sglangbot@gmail.com>
Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: sbeurnier <sbeurnier@together.ai>
Co-authored-by: YC Yen-Ching Tseng <yctseng@amd.com>
Co-authored-by: Wenyao Gao <105094497+edwingao28@users.noreply.github.com>
Co-authored-by: Alex Nails <alex.nails@radixark.ai>
Co-authored-by: khalilzhk <khalilzhk@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: yudian0504 <138860534+yudian0504@users.noreply.github.com>
Co-authored-by: yunkchen <chenyunkuo.cyk@alibaba-inc.com>
Co-authored-by: wduan-hai <wduan@humansand.ai>
Co-authored-by: amote-i <49533125+amote-i@users.noreply.github.com>
Co-authored-by: Cherry_ming <136634645@qq.com>
Co-authored-by: Ratish P <114130421+Ratish1@users.noreply.github.com>
Co-authored-by: YAMY <74099316+YAMY1234@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@mac.lan>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
Co-authored-by: Derek Yu <81697272+DerekY2@users.noreply.github.com>
Co-authored-by: Noa Neria <noa@run.ai>
Co-authored-by: Hanlin Bi <hanlinbi@umich.edu>
Co-authored-by: Prozac614 <dwt614707404@163.com>
Co-authored-by: David Cheung <d7cheung@gmail.com>
Co-authored-by: Mook <68294499+Godmook@users.noreply.github.com>
Co-authored-by: Khoa Pham <khoa.pham@radixark.ai>
Co-authored-by: foraxe <73625538+foraxe@users.noreply.github.com>
Co-authored-by: yunzhi <ningyunxiao.nyx@antgroup.com>
Co-authored-by: DarkSharpness <2040703891@qq.com>
Co-authored-by: Todobe <43903496+Todobe@users.noreply.github.com>
Co-authored-by: ori <39351881+froststeam@users.noreply.github.com>
Co-authored-by: Thomas <zs033@qq.com>
Co-authored-by: zhangshuai (S) <z00836796@china.huawei.com>
Co-authored-by: lviy <142899752+lviy@users.noreply.github.com>
Co-authored-by: Tingwei Huang <huangtingwei9988@gmail.com>
Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com>
Co-authored-by: Ricardo-M-L <69202550+Ricardo-M-L@users.noreply.github.com>
Co-authored-by: Kelon <kelonlu@163.com>
Co-authored-by: cen121212 <luochen23@huawei.com>
b8zhong added a commit that referenced this pull request Apr 3, 2026
| `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Apply per token group quantization kernel with fused silu and mul and masked m | `false` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
| `SGLANG_FORCE_NVFP4_MARLIN` | Force using NVFP4 Marlin fallback kernels even on Blackwell GPUs with native FP4 support | `false` |
| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | `` |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflict

Copy link
Copy Markdown
Contributor Author

@Godmook Godmook Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi. @b8zhong Thanks for the heads-up and sorry for this problems....
I traced all 16 files against current main and found 5 files with stale merge resolution from my earlier main merge. Here's the full list what I saw:

  1. docs/references/environment_variables.md

SGLANG_FLASHINFER_FP4_GEMM_BACKEND line — removed by #21536, I accidentally kept it. I'll fix it.
2. python/sglang/srt/environ.py

SGLANG_HICACHE_MAX_PINNED_RATIO — removed by #21884
SGLANG_ENABLE_MM_SPLITTING — removed by #21899

  1. python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py

silu_and_mul moved from sgl_kernel → sglang.jit_kernel.activation (#21766)
4. compressed_tensors_w4a4_nvfp4.py

Missing and not get_fp4_gemm_runner_backend().is_cutlass() guard on the flashinfer path
5. modelopt_quant.py

  • cutlass_fp4_gemm import changed to top-level try/except
  • New CUTLASS FP4 GEMM code path added
  • Same .is_cutlass() guard missing

The remaining 11 files are identical to main. I'll rebase on latest main to resolve all of these cleanly If you approve my plan. Really sorry about that...

if not isinstance(hf_quant_config, dict):
hf_quant_config = hf_quant_config.to_dict()
hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
# For modelopt, route to FP4 vs FP8 config based on quant_algo
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflict?

@Fridge003
Copy link
Copy Markdown
Collaborator

@Godmook This PR is just reverted due to the merge conflicts. Please rebase your branch and reland with a new PR, thanks!

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 3, 2026

@Godmook This PR is just reverted due to the merge conflicts. Please rebase your branch and reland with a new PR, thanks!

Yap. I talked with @b8zhong and I'll firstly do more Benchmark tests and after saw that datas, I think we have to decide it is good to land or not. Thanks for help!

@Godmook Godmook restored the nvfp4-marlin-fallback branch April 6, 2026 01:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

blackwell SM100/SM120 documentation Improvements or additions to documentation high priority jit-kernel quant LLM Quantization run-ci sgl-kernel

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants