
Conversation

@BBuf
Collaborator

@BBuf BBuf commented Jun 6, 2025

H100 Benchmark FBGEMM GroupedGEMM Results

When running the benchmarks with triton==3.2.0, the following warning appears: warp specialization is unavailable with this Triton build, but persistent kernels and TMA load/store remain available.

/home/ubuntu/bbuf/sglang/benchmark/kernels/fbgemm/fbgemm_grouped_gemm.py:1104: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.

Qwen2-57B-A14B-Instruct BF16 TP4

python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4

grouped-gemm-performance:
    batch_size  FBGEMM Grouped GEMM BF16  SGLang Grouped GEMM BF16
0          1.0                  0.032352                  0.022272
1          2.0                  0.032096                  0.022080
2          4.0                  0.032640                  0.021984
3          8.0                  0.031840                  0.021472
4         16.0                  0.030832                  0.021536
5         32.0                  0.032192                  0.021632
6         64.0                  0.393504                  0.595008
7        128.0                  0.393872                  0.598048
8        256.0                  0.394848                  0.589760
9        512.0                  0.397488                  0.605888
10      1024.0                  0.401248                  0.581952
11      2048.0                  0.407232                  0.559232
12      4096.0                  0.416368                  0.717936


Qwen2-57B-A14B-Instruct FP8 W8A8 TP4

python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 

grouped-gemm-performance:
    batch_size  FBGEMM Grouped GEMM FP8  SGLang Grouped GEMM FP8
0          1.0                 0.042560                 0.022336
1          2.0                 0.041312                 0.022128
2          4.0                 0.040384                 0.022240
3          8.0                 0.041184                 0.022016
4         16.0                 0.040128                 0.022816
5         32.0                 0.014272                 0.021440
6         64.0                 0.212832                 0.595040
7        128.0                 0.211328                 0.598688
8        256.0                 0.211776                 0.590992
9        512.0                 0.213504                 0.606304
10      1024.0                 0.216864                 0.582624
11      2048.0                 0.220512                 0.558128
12      4096.0                 0.227296                 0.718848


meta-llama/Llama-4-Scout-17B-16E-Instruct BF16 TP8

python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model meta-llama/Llama-4-Scout-17B-16E-Instruct --tp-size 8 

grouped-gemm-performance:
    batch_size  FBGEMM Grouped GEMM BF16  SGLang Grouped GEMM BF16
0          1.0                  0.034592                  0.022816
1          2.0                  0.033440                  0.022016
2          4.0                  0.033984                  0.022400
3          8.0                  0.324592                  0.532960
4         16.0                  0.321024                  0.516960
5         32.0                  0.322736                  0.695840
6         64.0                  0.321184                  0.607008
7        128.0                  0.321264                  0.475136
8        256.0                  0.321984                  0.419232
9        512.0                  0.325728                  0.363392
10      1024.0                  0.339616                  0.693824
11      2048.0                  0.396928                  1.383792
12      4096.0                  0.732640                  2.761792


meta-llama/Llama-4-Scout-17B-16E-Instruct FP8 TP8

python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model meta-llama/Llama-4-Scout-17B-16E-Instruct --tp-size 8 --use-fp8-w8a8

grouped-gemm-performance:
    batch_size  FBGEMM Grouped GEMM FP8  SGLang Grouped GEMM FP8
0          1.0                 0.042336                 0.020592
1          2.0                 0.006464                 0.013536
2          4.0                 0.006464                 0.014112
3          8.0                 0.171712                 0.531744
4         16.0                 0.170944                 0.518208
5         32.0                 0.170432                 0.693952
6         64.0                 0.172704                 0.608352
7        128.0                 0.173248                 0.475200
8        256.0                 0.175040                 0.420544
9        512.0                 0.178400                 0.367200
10      1024.0                 0.196736                 0.697968
11      2048.0                 0.230688                 1.385600
12      4096.0                 0.383872                 2.766432


The conclusion is that FBGEMM's grouped GEMM delivers significant speedups over SGLang's implementation for MoE models at medium and large batch sizes (up to several times faster in the runs above), while SGLang's kernel remains faster at the smallest batch sizes. The kernel can be applied directly to SGLang's EP-MoE grouped GEMM path to boost performance for fp16/bf16 and per-tensor-quantized fp8 workloads.

Limitation

The current limitation is that warp-specialized kernels are unavailable unless Triton is built from Meta's warp-specialization branch (see the warning above). Additionally, this kernel currently supports only fp16/bf16 and per-tensor-quantized FP8 W8A8; further modifications would be needed for DeepSeek compatibility (which relies on block-wise FP8 quantization).
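For orientation, here is a minimal sketch of how the grouped GEMM wrapper added in this PR might be called. The import path, argument order, and weight layout are assumptions based on the PR description and should be checked against benchmark/kernels/fbgemm/fbgemm_grouped_gemm.py:

    import torch

    # Wrapper included in this PR (import path assumed; run from benchmark/kernels/fbgemm/).
    from fbgemm_grouped_gemm import grouped_gemm

    G, M_total, K, N = 8, 256, 4096, 1408  # illustrative shapes only
    x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
    # Expert weights stacked along the first dimension: (G * N, K) layout assumed.
    w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)
    # m_sizes[i] = number of rows of x routed to expert i; rows of x are assumed to be
    # sorted by expert, and sum(m_sizes) must equal M_total.
    m_sizes = torch.full((G,), M_total // G, dtype=torch.int64, device="cuda")

    out = grouped_gemm(x, w, m_sizes)  # expected output shape: (M_total, N)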

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Hello @BBuf, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! gemini-code-assist here, providing a summary of this pull request. This PR introduces a benchmark script and the necessary kernel code to compare the performance of FBGEMM's grouped GEMM implementation against the existing SGLang grouped GEMM kernel, specifically targeting Mixture-of-Experts (MoE) models. The goal is to evaluate if the FBGEMM kernel can offer performance improvements for MoE layers in SGLang, particularly under BF16 and FP8 W8A8 conditions. The PR includes the FBGEMM kernel code itself and a detailed benchmark script that tests various batch sizes and model configurations, presenting the results and highlighting the potential performance gains observed with FBGEMM.

Highlights

  • New Benchmark Script: Adds a new Python script (benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py) to benchmark FBGEMM's grouped GEMM kernel against SGLang's existing implementation.
  • FBGEMM Kernel Integration: Includes the FBGEMM grouped GEMM kernel code (benchmark/kernels/fbgemm/fbgemm_grouped_gemm.py), likely copied from the PyTorch FBGEMM project, supporting both standard (BF16/FP16) and FP8 row-wise quantized inputs.
  • Performance Comparison: The benchmark results provided in the description show that the FBGEMM kernel can achieve significant performance improvements over the current SGLang grouped GEMM for various model sizes and data types, especially at larger batch sizes.
  • Correctness Verification: Includes a test script (benchmark/kernels/fbgemm/test_grouped_gemm.py) to verify the numerical correctness of the integrated FBGEMM kernel against the SGLang implementation for different group configurations (uniform and non-uniform) and data types.
  • Triton Warp Specialization: Notes a limitation where Triton's warp-specialized features might not be available without a specific Triton build, which could impact potential performance gains from the FBGEMM kernel's advanced features.

Changelog

  • benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py
    • Added a new script for benchmarking FBGEMM grouped GEMM.
    • Includes logic to load model configurations from Hugging Face transformers.
    • Functions to create test data for BF16 and FP8 W8A8.
    • Uses Triton's perf_report for benchmarking across different batch sizes.
    • Compares FBGEMM and SGLang kernels.
    • Includes an optional correctness verification step.
  • benchmark/kernels/fbgemm/fbgemm_grouped_gemm.py
    • Added the FBGEMM grouped GEMM kernel implementation.
    • Includes Triton JIT kernels for standard and FP8 row-wise GEMM.
    • Supports TMA load/store and warp specialization (conditional on Triton build).
    • Includes Python wrappers grouped_gemm and grouped_gemm_fp8_rowwise.
  • benchmark/kernels/fbgemm/test_grouped_gemm.py
    • Added a new script for testing the correctness of the FBGEMM grouped GEMM kernel.
    • Includes functions to create uniform and non-uniform group configurations.
    • Tests BF16 correctness against SGLang's kernel.
    • Tests FP8 functionality (shape and dtype) for the FBGEMM kernel.
    • Uses pytest for test execution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a benchmark for comparing FBGEMM's grouped GEMM kernel against SGLang's existing implementation, along with the FBGEMM kernel code itself (adapted from the FBGEMM repository) and unit tests. The benchmark results provided in the PR description are informative and demonstrate potential performance benefits with FBGEMM, especially for FP8 and larger batch sizes.

The code for the benchmark script, the FBGEMM kernels, and the tests is generally well-structured. The use of Triton for the kernel implementation is appropriate for high-performance computing. The PR also clearly outlines current limitations, such as the dependency on a specific Triton build for warp-specialized features.

I've identified a few areas for improvement, primarily concerning the robustness of the benchmark data generation and long-term maintainability of the Triton kernel code. Addressing these would enhance the reliability and future upkeep of this valuable contribution.

Summary of Findings

  • Benchmark Data Generation Correctness: The m_sizes calculation in create_test_data and create_fp8_test_data does not correctly handle cases where batch_size is not divisible by num_groups. This can lead to some input tokens not being processed, affecting benchmark accuracy. The seg_indptr calculation also needs to be updated accordingly.
  • Model Configuration Robustness: The get_model_config function's fallback for unknown model architectures might use incorrect parameters if default attribute names don't apply, potentially leading to benchmarks with unintended model shapes.
  • Triton Kernel Maintainability: The FBGEMM Triton kernel code (fbgemm_grouped_gemm.py) contains significant duplication across different kernel versions, which could pose long-term maintainability challenges. This is acknowledged by existing TODO comments.

Merge Readiness

This pull request adds valuable benchmarking capabilities and high-performance FBGEMM kernels. The benchmark results are promising. However, there are a few issues that should be addressed to ensure the benchmark's correctness and the code's long-term maintainability:

  1. Critical: The generation of m_sizes and seg_indptr in the benchmark script needs to be robust to cases where batch_size is not divisible by num_groups to ensure all input data is processed.
  2. Consideration: The model configuration fallback logic could be made more informative to users.
  3. Consideration: The acknowledged code duplication in Triton kernels is a maintainability concern for the future.

I am unable to approve pull requests. Given the high-severity issue related to benchmark data generation, I recommend that these changes be made before merging to ensure the reliability of the benchmark. The medium-severity items should also be considered for improvement.

Comment on lines +66 to +69
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
    (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)
Severity: high

The current calculation of m_sizes assumes batch_size is perfectly divisible by num_groups. If not, sum(m_sizes) (where m_sizes elements are batch_size // num_groups) would be less than batch_size. This means not all input tokens from x would be processed, potentially leading to incorrect or misleading benchmark results for certain configurations.

Could m_sizes be adjusted to ensure all batch_size tokens are distributed among the groups, for example, by distributing any remainder tokens? This would make the benchmark more robust for arbitrary batch_size and num_groups combinations.

    base_tokens_per_group = batch_size // num_groups
    remainder_tokens = batch_size % num_groups
    m_sizes_list = [base_tokens_per_group] * num_groups
    for i in range(remainder_tokens):
        m_sizes_list[i] += 1
    m_sizes = torch.tensor(m_sizes_list, dtype=torch.int64, device="cuda")
    # Ensure sum(m_sizes) == batch_size, critical for processing all input tokens.
    # Consider adding: assert torch.sum(m_sizes).item() == batch_size, "Sum of m_sizes must equal batch_size"
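For what it's worth, the same remainder distribution can be done without a Python loop (a sketch with identical semantics, reusing the variable names above):

    m_sizes = torch.full(
        (num_groups,), batch_size // num_groups, dtype=torch.int64, device="cuda"
    )
    m_sizes[: batch_size % num_groups] += 1
    assert int(m_sizes.sum()) == batch_size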

Comment on lines +87 to +89
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
for i in range(1, num_groups + 1):
    seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
Severity: high

Related to the m_sizes calculation: if m_sizes is updated to correctly distribute all tokens (especially when batch_size is not divisible by num_groups), the calculation for seg_indptr also needs to use these corrected m_sizes rather than the potentially inaccurate tokens_per_group.

SGLang's grouped GEMM relies on seg_indptr to define token segments for each expert. Using the actual m_sizes for its calculation is crucial for correctness.

    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
    # seg_indptr[0] is 0. Calculate cumulative sum of actual m_sizes.
    # This assumes 'm_sizes' has been redefined to correctly sum to batch_size.
    torch.cumsum(m_sizes, dim=0, out=seg_indptr[1:])
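A quick consistency check after building seg_indptr this way (a sketch; it assumes m_sizes has been fixed to sum to batch_size):

    assert seg_indptr[0].item() == 0
    assert seg_indptr[-1].item() == batch_size
    assert torch.equal(seg_indptr[1:] - seg_indptr[:-1], m_sizes)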

Comment on lines +108 to +111
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
    (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)
Severity: high

Similar to the create_test_data function, the m_sizes calculation here in create_fp8_test_data might not account for all tokens if batch_size is not divisible by num_groups. This could affect the FP8 benchmark path.

Would it be beneficial to apply the same robust m_sizes calculation here (distributing remainder tokens) to ensure all batch_size tokens are processed?

    base_tokens_per_group = batch_size // num_groups
    remainder_tokens = batch_size % num_groups
    m_sizes_list = [base_tokens_per_group] * num_groups
    for i in range(remainder_tokens):
        m_sizes_list[i] += 1
    m_sizes = torch.tensor(m_sizes_list, dtype=torch.int64, device="cuda")

Comment on lines +49 to +51
else:
    num_groups = config.num_local_experts
    intermediate_size = config.intermediate_size
Severity: medium

The fallback logic in get_model_config for unhandled model architectures defaults to using config.num_local_experts and config.intermediate_size. If a new, unhandled architecture doesn't conform to these attribute names, an AttributeError will occur. While this error is caught in main() and a hardcoded default configuration is used, this might lead to the benchmark running with parameters that don't match the user's intended model.

Could we enhance this fallback? For example, by printing a more specific warning when this else branch is taken for an architecture not explicitly listed in the if/elif conditions? This would alert the user that the benchmark might be using assumed (and potentially incorrect) parameters for their model.

    else:
        # Fallback for unrecognized architectures.
        print(f"Warning: Model architecture '{config.architectures[0]}' not explicitly handled. "
              f"Attempting to use default attributes 'num_local_experts' and 'intermediate_size'. "
              f"This may lead to incorrect benchmark parameters if these attributes are not applicable or missing.")
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size
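An alternative sketch (not the PR's code) that fails fast instead of relying on attributes that may not exist:

    else:
        num_groups = getattr(config, "num_local_experts", None)
        intermediate_size = getattr(config, "intermediate_size", None)
        if num_groups is None or intermediate_size is None:
            raise ValueError(
                f"Unsupported architecture '{config.architectures[0]}': expected "
                "'num_local_experts' and 'intermediate_size' attributes on the config."
            )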

iterated_tiles += num_tiles


# TODO(shikaili): Too much code duplication. Need to refactor.
Severity: medium

This TODO comment correctly highlights a key maintainability aspect: there's significant code duplication between the different Triton kernel versions (e.g., standard vs. warp-specialized, FP16/BF16 vs. FP8).

While refactoring Triton kernels to minimize duplication without sacrificing performance is non-trivial, especially when adapting existing code, have you considered any strategies for this as a potential long-term improvement? For instance, could some common logic be encapsulated in shared @triton.jit helper functions, or could Python-level pre-processing/code generation help manage variations if the core algorithmic structure is similar?
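As a reference point, Triton does support factoring shared logic into @triton.jit helper functions that are inlined into the calling kernels; a minimal illustrative sketch (not taken from the FBGEMM code):

    import triton
    import triton.language as tl

    @triton.jit
    def _load_tile(base_ptr, offs_m, offs_k, stride_m, stride_k):
        # Shared tile-load helper, callable from multiple @triton.jit kernels.
        return tl.load(
            base_ptr + offs_m[:, None] * stride_m + offs_k[None, :] * stride_k
        )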

Member

@zhyncs zhyncs left a comment

note: dead code

@zhyncs zhyncs merged commit bae4fdc into main Jun 7, 2025
4 checks passed
@zhyncs zhyncs deleted the add_fbgemm_grouped_gemm_benchmark branch June 7, 2025 09:57
@jwfromm

jwfromm commented Jun 11, 2025

Hi @zhyncs, good to see you again! It's great that FBGEMM can bring some value to the SGLang community! I just wanted to note that FBGEMM's GenAI ops can now be installed simply with pip install fbgemm-gpu-genai, which would let SGLang use not only the Triton kernels highlighted in this PR but also the highly optimized CUTLASS grouped GEMM operators, which are comparably fast and often easier to deploy (https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1181). I'd love to see more integration between FBGEMM and the broader GenAI optimization community :)
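A rough sketch of what that could look like; the package name and install command are quoted from the comment above, while the import path and op names are assumptions to verify against the linked quantize_ops.py:

    # pip install fbgemm-gpu-genai
    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed import that registers the GenAI ops)

    # The CUTLASS grouped GEMM entry points are then exposed as torch.ops.fbgemm.* operators.
    # Exact op names and argument layouts should be taken from the linked quantize_ops.py, e.g.:
    # out = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)  # name/signature assumed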

@yiakwy-xpu-ml-framework-team
Contributor

yiakwy-xpu-ml-framework-team commented Jun 12, 2025

@jwfromm do you have plans to support FP8 (nv_fp8_e4m3, hip_fp8_e4m3_fnuz)?

I just checked the source code; HIP is not supported yet:

https://github.com/pytorch/FBGEMM/blob/2f2a1ef555571022e36aa66f15b826cb3c55ba73/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1656

cc @HaiShaw

jianan-gu pushed a commit to jianan-gu/sglang that referenced this pull request Jun 12, 2025
@jianyuh

jianyuh commented Jun 12, 2025

> When running the benchmarks with triton==3.2.0, the following warning appears: warp specialization is unavailable with this Triton build, but persistent kernels and TMA load/store remain available.

If we can run with triton 3.3, we might expect higher perf with warp specializations. cc @levendlee
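A quick way to confirm which Triton build is active (sketch):

    import triton
    print(triton.__version__)  # the warning above is emitted when the build lacks warp-specialization support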

@jwfromm

jwfromm commented Jun 12, 2025

@yiakwy-xpu-ml-framework-team Just to add a bit more to Jianyu's pointers, I actually think FBGEMM brings the most value in enabling AMD kernels since the options for Nvidia outside FBGEMM are a bit more fleshed out. We have full support across Nvidia and AMD for all kernels except the mixed dtype variety (FP8 activation INT4 weight) that you pointed out. Fortunately, AMD GPUs tend to have more RAM so the INT4 weight compression is less relevant. I definitely recommend trying out some of the FP8 activation FP8 weight kernels on AMD, I think you'll be really happy with the performance :)

walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025