
fix: default FP4 GEMM backend to flashinfer_cudnn on SM120 (Blackwell)#20047

Merged
Fridge003 merged 5 commits into sgl-project:main from voipmonitor:fix/fp4-gemm-default-sm120
Mar 9, 2026

Conversation

@voipmonitor
Contributor

@voipmonitor voipmonitor commented Mar 6, 2026

Summary

  • Change the default --fp4-gemm-backend from flashinfer_cutlass to auto
  • auto now selects flashinfer_cudnn on SM120 (Blackwell), flashinfer_cutlass on other architectures (no behavior change for non-Blackwell)
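The selection rule above can be sketched as a small pure function (names here are illustrative; sglang's actual resolution code differs):

```python
# Hedged sketch of the "auto" resolution rule described above.
# The real sglang code lives in fp4_utils.py and uses other helpers.
def resolve_fp4_gemm_backend(backend: str, sm_version: int) -> str:
    """Resolve the --fp4-gemm-backend value to a concrete backend.

    sm_version is the compute capability as an integer, e.g. 120 for SM120.
    """
    if backend != "auto":
        return backend  # an explicit user override always wins
    if sm_version == 120:
        # flashinfer_cutlass produces NaNs on SM120 under concurrent load
        return "flashinfer_cudnn"
    return "flashinfer_cutlass"
```

With this rule, `auto` behaves identically to the old default everywhere except SM120.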

Problem

The flashinfer_cutlass FP4 GEMM backend produces NaN values in dense MLP layers when processing heterogeneous batches on SM120 (Blackwell) GPUs. This causes torch.multinomial to crash with `probability tensor contains either inf, nan or element < 0` under concurrent request load.

Key findings from debugging:

  • NaN originates in layer 0 dense MLP (not attention, not MoE)
  • Only 3-5 rows out of ~2000 get full NaN (all hidden dims)
  • Batch-composition dependent — same input replayed alone doesn't trigger NaN; requires heterogeneous batches
  • Identical NaN pattern across all 8 TP ranks (not a NCCL issue)
  • Crashes consistently within 1-4 rounds of 64 concurrent requests
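A check like the one used to spot the full-NaN rows can be written as follows (illustrative pure Python; the actual debugging presumably inspected torch tensors):

```python
import math

def nan_rows(batch):
    """Return indices of rows that are NaN across every hidden dim,
    the pattern observed above (3-5 full-NaN rows out of ~2000)."""
    return [
        i for i, row in enumerate(batch)
        if all(math.isnan(x) for x in row)
    ]
```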

Fix

flashinfer_cudnn does not exhibit this issue. Stress-tested:

  • flashinfer_cutlass: NaN crash within 1-4 rounds of 64 concurrent requests
  • flashinfer_cudnn: 3,200 requests (50 rounds × 64 concurrent) — zero NaN, zero crashes

This PR changes the default to auto which auto-selects flashinfer_cudnn on SM120 and preserves flashinfer_cutlass on all other architectures.

Test plan

  • Verified flashinfer_cudnn produces zero NaN on SM120 with 3,200 concurrent requests
  • Verified --fp4-gemm-backend flashinfer_cutlass still works as explicit override
  • Non-Blackwell GPUs: auto resolves to flashinfer_cutlass (existing behavior unchanged)

Environment tested

  • 8× NVIDIA RTX PRO 6000 Blackwell Server Edition (SM120, 96GB)
  • GLM-5-NVFP4-MTP with --quantization modelopt_fp4, TP=8
  • SGLang built from main

Full investigation details: #20043

🤖 Generated with Claude Code

The flashinfer_cutlass FP4 GEMM backend produces NaN values in dense MLP
layers when processing heterogeneous batches on SM120 (Blackwell) GPUs.
This causes torch.multinomial crashes under concurrent request load.

The flashinfer_cudnn backend does not exhibit this issue — stress-tested
with 3,200 concurrent requests (50 rounds x 64 concurrent) with zero NaN.

Changes:
- Change fp4_gemm_runner_backend default from "flashinfer_cutlass" to "auto"
- Add SM120 detection in auto-resolve: select flashinfer_cudnn on Blackwell,
  flashinfer_cutlass on other architectures (preserving existing behavior)

Fixes: sgl-project#20043

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes


This pull request introduces a crucial fix for numerical stability issues encountered on NVIDIA SM120 (Blackwell) GPUs when using FP4 GEMM operations. By changing the default backend selection to an 'auto' mode, the system now dynamically chooses the more stable flashinfer_cudnn for Blackwell architectures, preventing NaN propagation and ensuring robust operation under high load. This change maintains existing behavior for non-Blackwell GPUs while significantly improving reliability for the latest hardware.

Highlights

  • Default FP4 GEMM Backend Change: The default --fp4-gemm-backend setting has been changed from flashinfer_cutlass to auto in the server arguments.
  • Intelligent Backend Selection for Blackwell GPUs: The auto backend selection now intelligently defaults to flashinfer_cudnn for SM120 (Blackwell) GPUs, while retaining flashinfer_cutlass for other architectures. This change specifically addresses a critical issue on Blackwell GPUs.
  • Resolution of NaN Issue on SM120 (Blackwell): This update fixes a problem where the flashinfer_cutlass backend produced NaN values in dense MLP layers on SM120 (Blackwell) GPUs, leading to crashes during concurrent request loads. The flashinfer_cudnn backend has been verified to be stable under these conditions.


Changelog
  • python/sglang/srt/layers/quantization/fp4_utils.py
    • Imported is_sm120_supported utility function.
    • Implemented logic within initialize_fp4_gemm_config to automatically select flashinfer_cudnn as the FP4 GEMM backend if an SM120 (Blackwell) GPU is detected when backend is set to auto.
    • Added a logger info message to indicate when flashinfer_cudnn is auto-selected for SM120.
  • python/sglang/srt/server_args.py
    • Updated the default value of fp4_gemm_runner_backend from flashinfer_cutlass to auto.
    • Modified the help text for the fp4_gemm_runner_backend CLI argument to accurately describe the new auto behavior, specifying its selection logic for SM120/Blackwell and other architectures.
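Per the changelog, the new detection presumably boils down to a compute-capability check along these lines (a guess; the real is_sm120_supported helper may be implemented differently):

```python
# Illustrative sketch: SM120 corresponds to compute capability (12, 0),
# as reported by e.g. torch.cuda.get_device_capability(). The name
# is_sm120 and this exact tuple check are assumptions, not sglang's code.
def is_sm120(major: int, minor: int) -> bool:
    return (major, minor) == (12, 0)
```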
Activity
  • The author verified that flashinfer_cudnn produces zero NaN values on SM120 with 3,200 concurrent requests.
  • The author confirmed that --fp4-gemm-backend flashinfer_cutlass still functions correctly as an explicit override.
  • The author noted that for non-Blackwell GPUs, auto will continue to resolve to flashinfer_cutlass, preserving existing behavior.

@voipmonitor
Contributor Author

@b8zhong please check

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request changes the default FP4 GEMM backend to auto to address a NaN issue on SM120 (Blackwell) GPUs by selecting flashinfer_cudnn. For other architectures, auto now defaults to flashinfer_cutlass. While this fixes the issue for Blackwell, I've raised a concern about a potential regression for non-Blackwell architectures. The previous behavior of auto was described as selecting between flashinfer_cudnn and flashinfer_cutlass based on CUDA/cuDNN versions, but the new implementation hardcodes flashinfer_cutlass for non-Blackwell systems. This could impact users who were benefiting from the dynamic selection.

Comment on lines +89 to +90:

```python
else:
    backend = "flashinfer_cutlass"
```
Contributor


Severity: high

This change hardcodes the auto backend to flashinfer_cutlass for non-Blackwell architectures. However, the previous help text for this option stated: auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version. This suggests there might have been more complex logic for auto-selection that is now being removed, which could be a regression for users on non-Blackwell hardware who were relying on auto to potentially select flashinfer_cudnn.

While the PR description mentions that the behavior is unchanged for non-Blackwell, the discrepancy with the old help text is concerning. If the old help text was inaccurate and auto always resolved to flashinfer_cutlass, then this change is fine. Otherwise, the previous auto-selection logic should be preserved here for non-SM120 architectures.

Collaborator


Actually, this was already the case before (SM100/103 and SM120 would both pick flashinfer_cutlass, due to a memory leak). So it's alright, I think.

Collaborator

@b8zhong b8zhong left a comment


@Fridge003, since #18350, do you still think it's hacky for this PR to set the default to auto? It looks like the SM120 CUTLASS-based implementation has a bug, so SM120 and SM100 will now resolve to different backends.

```diff
 help="Choose the runner backend for NVFP4 GEMM operations. "
-"Options: 'flashinfer_cutlass' (default), "
-"'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), "
+"Options: 'auto' (default; selects flashinfer_cudnn on SM120/Blackwell, flashinfer_cutlass otherwise), "
```
Collaborator


QQ: @Fridge003 do you still think this is hacky? Because it looks like the SM120 impl has a bug, and now the two devices will resolve to different backend. Otherwise, we can just hardcode cuDNN for SM120 temporarily...

The CUTLASS FP4 GEMM kernel on SM120 (Blackwell) intermittently skips
writing certain output tiles, leaving uninitialized memory (NaN) in
contiguous 128-aligned blocks. Pre-zeroing the output buffer ensures
these unwritten tiles contain 0 instead of NaN.

Verified: 1,280 concurrent requests with zero NaN (vs hundreds without
the fix). The pre-zeroing is applied to all backends as a safety measure.

Upstream bug report: flashinfer-ai/flashinfer#2708
Fixes: sgl-project#20043

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
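The pre-zeroing workaround that commit describes can be sketched as follows (illustrative Python lists standing in for torch tensors; run_gemm_with_prezero is a made-up name):

```python
def run_gemm_with_prezero(gemm_kernel, m, n):
    # Allocate the output zero-filled rather than uninitialized, so any
    # tile the buggy kernel skips reads back as 0.0 instead of garbage/NaN.
    out = [[0.0] * n for _ in range(m)]
    gemm_kernel(out)  # kernel may write only some of the output tiles
    return out
```

In the real code this corresponds to allocating the output with torch.zeros instead of torch.empty.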
@github-actions github-actions bot added the quant LLM Quantization label Mar 6, 2026
@voipmonitor
Contributor Author

voipmonitor commented Mar 7, 2026

Root cause found — FlashInfer missing GDC compile flags

The NaN crash is caused by missing -DCUTLASS_ENABLE_GDC_FOR_SM100=1 compile flags in FlashInfer's JIT compilation. This causes PDL synchronization barriers to be compiled as no-ops, creating a race condition.

Fix PR: flashinfer-ai/flashinfer#2716

This PR (defaulting to cudnn on SM120) remains useful as a workaround until the FlashInfer fix is released.

This probably also requires CUTLASS >= 4.3.0 (not certain, but that release also contains PDL fixes).

yzh119 pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Mar 9, 2026
…ags cause PDL synchronization barriers to compile as no-ops (#2716)

## Summary

All CUTLASS GEMM templates use `enablePDL=true` (Programmatic Dependent
Launch), but the JIT compilation is missing
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1`
compile flags. Without these flags, `wait_on_dependent_grids()` and
`launch_dependent_grids()` in CUTLASS `grid_dependency_control.h`
compile as **empty no-ops**, eliminating the synchronization barriers
needed for safe PDL execution.

## Root Cause

In `cutlass/include/cutlass/arch/grid_dependency_control.h`:

```cpp
CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))  // only defined when CUTLASS_ENABLE_GDC_FOR_SM100 is set
  asm volatile("griddepcontrol.wait;");
#endif
}
```

The `CUTLASS_GDC_ENABLED` macro is only defined when
`CUTLASS_ENABLE_GDC_FOR_SM100` is passed as a compile flag. Without it,
PDL launches kernels with overlap enabled at the host level
(`cudaLaunchAttributeProgrammaticStreamSerialization`), but the
device-side synchronization barriers are compiled out — creating a race
condition.

## Symptoms

On SM120 (Blackwell RTX PRO 6000 / RTX 5090) with high concurrency (64+
simultaneous requests in SGLang with TP=8):
- CUTLASS FP4 GEMM intermittently fails to write output tiles
- Unwritten tiles contain uninitialized memory (NaN/garbage)
- NaN blocks are always contiguous and 128-aligned, matching CTA tile
boundaries
- `CUDA_LAUNCH_BLOCKING=1` eliminates the bug (confirms race condition)
- cudnn backend is unaffected (does not use CUTLASS PDL)
- Retry with identical inputs produces correct output

## Fix

Add `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` and
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to all affected GEMM JIT modules:
- `fp4_gemm_cutlass` (SM100)
- `fp4_gemm_cutlass_sm103` (SM103)
- `fp4_gemm_cutlass_sm120` (SM120)
- `fp8_gemm_cutlass` (SM100)
- `mxfp8_gemm_cutlass` (SM100)
- `gemm_sm120` (SM120 FP8 groupwise)

The `tgv_gemm` module already had `DCUTLASS_ENABLE_GDC_FOR_SM100`.

Note: `DCUTLASS_ENABLE_GDC_FOR_SM90` is needed because the SM120 CUTLASS
kernel (`sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`)
guards `launch_dependent_grids()` with `#ifdef
CUTLASS_ENABLE_GDC_FOR_SM90` instead of `SM100` (upstream CUTLASS bug).
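In practice the fix amounts to appending the two defines to each affected module's nvcc flag list, roughly as below (a hedged sketch; with_gdc_flags is not FlashInfer's actual API):

```python
# The two defines the commit adds; without them CUTLASS_GDC_ENABLED is
# never set and the griddepcontrol barriers compile to empty no-ops.
GDC_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",  # needed due to the upstream guard bug
]

def with_gdc_flags(nvcc_flags):
    """Append the GDC defines if not already present (idempotent)."""
    return list(nvcc_flags) + [f for f in GDC_FLAGS if f not in nvcc_flags]
```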

## Verification

| Configuration | Result |
|---|---|
| PDL=true, no GDC flags (current) | **NaN crash** under high concurrency |
| PDL=false (workaround) | OK |
| PDL=true + GDC flags (this PR) | **OK** — tested with 64 concurrent requests, multiple SGLang restarts from JIT cache |
| `CUDA_LAUNCH_BLOCKING=1` | OK (confirms race condition) |

## Environment

- Hardware: 8x NVIDIA RTX PRO 6000 Blackwell (SM120, 96GB)
- FlashInfer 0.6.4, CUTLASS 4.4.1
- SGLang with TP=8, EAGLE-v2, GLM-5-NVFP4-MTP model
- PyTorch 2.12.0.dev, CUDA 12.8+

## Related

- #2708
- sgl-project/sglang#20043
- sgl-project/sglang#20047


Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@Fridge003
Collaborator

/rerun-stage stage-c-test-4-gpu-b200

@github-actions
Contributor

github-actions bot commented Mar 9, 2026

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies).

@github-actions
Contributor

github-actions bot commented Mar 9, 2026

🔗 View workflow run

@Fridge003
Collaborator

Fridge003 merged commit d39ed07 into sgl-project:main on Mar 9, 2026
55 of 64 checks passed
liubiyongge pushed a commit to liubiyongge/sglang that referenced this pull request Mar 13, 2026
sgl-project#20047)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
@mratsim

mratsim commented Mar 17, 2026

By the way, is there a performance difference between flashinfer_cudnn and flashinfer_cutlass?

frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…ags cause PDL synchronization barriers to compile as no-ops (flashinfer-ai#2716)

ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…ags cause PDL synchronization barriers to compile as no-ops (flashinfer-ai#2716)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
sgl-project#20047)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>

Labels

quant LLM Quantization


Development

Successfully merging this pull request may close these issues.

NaN in hidden states with modelopt_fp4 quantization under concurrent load (GLM-5-NVFP4-MTP on Blackwell)

4 participants