Skip to content

Conversation

@PopSoda2002
Copy link
Contributor

@PopSoda2002 PopSoda2002 commented May 2, 2025

Motivation

In previous versions, FlashInfer 0.2.5 supports norm's PDL, but currently, norm's PDL is disabled by default. This PR modifies the code to enable it automatically on Hopper architecture.

Modifications

  1. Add is_hopper_arch() utility function to detect Hopper architecture (compute capability >= 9.0)
  2. Modify rmsnorm, fused_add_rmsnorm, gemma_rmsnorm and gemma_fused_add_rmsnorm functions to auto-enable PDL on Hopper
  3. Update documentation to reflect these changes

This implementation automatically enables PDL optimization on Hopper GPUs while maintaining backward compatibility by allowing explicit override through the parameter.

Checklist

@FlamingoPg FlamingoPg changed the title [Feat.] Enable PDL automatically on Hopper architecture [Feat] Enable PDL automatically on Hopper architecture May 2, 2025
@zhyncs
Copy link
Member

zhyncs commented May 2, 2025

Also please provide the performance benchmark after this enhancement

@PopSoda2002
Copy link
Contributor Author

Also please provide the performance benchmark after this enhancement

Yes, there is another guy who is testing the performance!

@hebiao064
Copy link
Collaborator

Also please provide the performance benchmark after this enhancement

Yes, there is another guy who is testing the performance!

Hi, how's the progress on the benchmark?

@PopSoda2002
Copy link
Contributor Author

Still working bro

Also please provide the performance benchmark after this enhancement

Yes, there is another guy who is testing the performance!

Hi, how's the progress on the benchmark?

@PopSoda2002
Copy link
Contributor Author

Here is my benchmark for testing(test on H100):
command:

python3 -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend fa3 --batch 16 --input-len 1024 --output-len 10

Before this PR:
image

After:
image

Thanks @Fridge003 for helping!

@FlamingoPg
Copy link
Collaborator

FlamingoPg commented May 22, 2025

batch_size hidden_size dtype w/o pdl w/ pdl
1 111 torch.float16 12.224000 9.632000
1 111 torch.bfloat16 10.976000 11.040000
1 500 torch.float16 11.008000 10.144000
1 500 torch.bfloat16 11.424000 9.632000
1 1024 torch.float16 11.360000 9.696000
1 1024 torch.bfloat16 11.392000 9.984000
1 3072 torch.float16 12.544000 10.304000
1 3072 torch.bfloat16 11.136000 11.264000
1 3584 torch.float16 11.616000 10.016000
1 3584 torch.bfloat16 11.296000 10.368000
1 4096 torch.float16 11.648000 11.456000
1 4096 torch.bfloat16 11.680000 11.456000
1 8192 torch.float16 13.248000 10.880000
1 8192 torch.bfloat16 12.128000 11.936000
1 16384 torch.float16 12.864000 11.904000
1 16384 torch.bfloat16 13.216000 11.584000
19 111 torch.float16 11.776000 11.584000
19 111 torch.bfloat16 12.864000 11.584000
19 500 torch.float16 12.544000 9.920000
19 500 torch.bfloat16 11.584000 9.920000
19 1024 torch.float16 11.776000 10.496000
19 1024 torch.bfloat16 12.672000 10.048000
19 3072 torch.float16 12.832000 11.520000
19 3072 torch.bfloat16 11.456000 10.528000
19 3584 torch.float16 12.992000 10.272000
19 3584 torch.bfloat16 12.992000 10.272000
19 4096 torch.float16 12.032000 11.744000
19 4096 torch.bfloat16 13.088000 10.336000
19 8192 torch.float16 12.576000 11.264000
19 8192 torch.bfloat16 12.672000 11.392000
19 16384 torch.float16 13.856000 12.704000
19 16384 torch.bfloat16 14.048000 13.728000
99 111 torch.float16 13.088000 10.464000
99 111 torch.bfloat16 13.088000 10.848000
99 500 torch.float16 11.392000 10.400000
99 500 torch.bfloat16 11.584000 10.016000
99 1024 torch.float16 11.552000 10.688000
99 1024 torch.bfloat16 12.960000 10.688000
99 3072 torch.float16 12.288000 10.688000
99 3072 torch.bfloat16 13.376000 11.136000
99 3584 torch.float16 12.640000 12.352000
99 3584 torch.bfloat16 12.640000 11.328000
99 4096 torch.float16 12.704000 11.552000
99 4096 torch.bfloat16 14.080000 11.552000
99 8192 torch.float16 16.192000 14.016000
99 8192 torch.bfloat16 16.160000 14.752000
99 16384 torch.float16 18.208001 15.776001
99 16384 torch.bfloat16 18.176001 17.120000
989 111 torch.float16 12.864000 11.648000
989 111 torch.bfloat16 12.896000 11.680000
989 500 torch.float16 13.408000 13.280000
989 500 torch.bfloat16 14.528000 12.224000
989 1024 torch.float16 18.975999 17.696001
989 1024 torch.bfloat16 18.975999 17.664000
989 3072 torch.float16 23.647999 22.399999
989 3072 torch.bfloat16 23.712000 22.431999
989 3584 torch.float16 24.831999 23.520000
989 3584 torch.bfloat16 24.831999 23.456000
989 4096 torch.float16 23.808001 22.368001
989 4096 torch.bfloat16 23.680000 21.152001
989 8192 torch.float16 35.808001 33.408001
989 8192 torch.bfloat16 34.880001 34.015998
989 16384 torch.float16 64.800002 63.616000
989 16384 torch.bfloat16 64.576000 64.032003

benchmark down cc @zhyncs @hebiao064

Copy link
Collaborator

@FlamingoPg FlamingoPg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with benchmark down

@hebiao064
Copy link
Collaborator

LFG!

@PopSoda2002
Copy link
Contributor Author

Hi @zhyncs, can you help to review this PR? I think it is ready to merge

@PopSoda2002 PopSoda2002 requested a review from zhyncs June 1, 2025 06:22
@Fridge003 Fridge003 added the ready-to-merge The PR is ready to merge after the CI is green. label Jun 1, 2025
@Fridge003 Fridge003 merged commit 2f7420b into sgl-project:main Jun 1, 2025
42 of 44 checks passed
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should cache this result instead of calling it every time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I will do it later

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More precisely, hopper or later architectures.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@Edenzzzz Edenzzzz mentioned this pull request Jun 1, 2025
6 tasks
Edenzzzz pushed a commit to Edenzzzz/sglang that referenced this pull request Jun 2, 2025
Layssy pushed a commit to Layssy/sglang-iaas that referenced this pull request Jun 9, 2025
xwu-intel pushed a commit to xwu-intel/sglang that referenced this pull request Jun 17, 2025
walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025
Merge branch 'sgl_20250610_sync_tag047 of [email protected]:Theta/SGLang.git into main

https://code.alipay.com/Theta/SGLang/pull_requests/52


Reviewed-by: 剑川 <[email protected]>


* [Bugfix] Fix slice operation when chunk size mismatch (sgl-project#6697)
* [Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (sgl-project#6703)
* [CI] Fix setup of disaggregation with different tp (sgl-project#6706)
* [PD] Remove Unnecessary Exception Handling for FastQueue.get() (sgl-project#6712)
* Fuse routed_scaling_factor in DeepSeek (sgl-project#6710)
* Overlap two kernels in DeepSeek with communication (sgl-project#6711)
* Minor refactor two-batch overlap (sgl-project#6682)
* Speed up when having padding tokens two-batch overlap (sgl-project#6668)
* [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (sgl-project#6479)
* Fix LoRA bench (sgl-project#6719)
* temp
* Fix PP for Qwen3 MoE (sgl-project#6709)
* [feat] triton kernel for get_last_loc (sgl-project#6676)
* [fix] more mem for draft_extend cuda_graph (sgl-project#6726)
* [PD] bug fix:  Update status if nixl receiver send a a dummy req. (sgl-project#6720)
* Tune memory arguments on B200 (sgl-project#6718)
* Add DeepSeek-R1-0528 function call chat template (sgl-project#6725)
* refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor `parse_streaming_increment` (sgl-project#6715)
* Add draft extend CUDA graph for Triton backend (sgl-project#6705)
* refactor apply_w8a8_block_fp8_linear in fp (sgl-project#6545)
* [PD] Support completion endpoint (sgl-project#6729)
* PD Rust LB (PO2) (sgl-project#6437)
* Super tiny enable sole usage of expert distribution metrics and update doc (sgl-project#6680)
* Support picking variants of EPLB algorithms (sgl-project#6728)
* Support tuning DeepEP configs (sgl-project#6742)
* [test] add ut and bm for get_last_loc (sgl-project#6746)
* Fix mem_fraction_static for AMD CI (sgl-project#6748)
* [fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (sgl-project#6265)
* Improve EPLB logical to physical dispatch map (sgl-project#6727)
* Update DeepSeek-R1-0528 function call chat template (sgl-project#6765)
* [PD] Optimize time out logic and add env var doc for mooncake (sgl-project#6761)
* Fix aiohttp 'Chunk too big' in bench_serving (sgl-project#6737)
* Support sliding window in triton backend (sgl-project#6509)
* Fix shared experts fusion error (sgl-project#6289)
* Fix one bug in the grouped-gemm triton kernel (sgl-project#6772)
* update llama4 chat template and pythonic parser (sgl-project#6679)
* feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (sgl-project#6784)
* Support token-level quantization for EP MoE (sgl-project#6782)
* Temporarily lower mmlu threshold for triton sliding window backend (sgl-project#6785)
* ci: relax test_function_call_required (sgl-project#6786)
* Add intel_amx backend for Radix Attention for CPU (sgl-project#6408)
* Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
* fix(PD-disaggregation): Can not get local ip (sgl-project#6792)
* [FIX] mmmu bench serving result display error (sgl-project#6525) (sgl-project#6791)
* Bump torch to 2.7.0 (sgl-project#6788)
* chore: bump sgl-kernel v0.1.5 (sgl-project#6794)
* Improve profiler and integrate profiler in bench_one_batch_server (sgl-project#6787)
* chore: upgrade sgl-kernel v0.1.5 (sgl-project#6795)
* [Minor] Always append newline after image token when parsing chat message (sgl-project#6797)
* Update CI tests for Llama4 models (sgl-project#6421)
* [Feat] Enable PDL automatically on Hopper architecture (sgl-project#5981)
* chore: update blackwell docker (sgl-project#6800)
* misc: cache is_hopper_arch (sgl-project#6799)
* Remove contiguous before Flashinfer groupwise fp8 gemm (sgl-project#6804)
* Correctly abort the failed grammar requests & Improve the handling of abort (sgl-project#6803)
* [EP] Add cuda kernel for moe_ep_pre_reorder (sgl-project#6699)
* Add draft extend CUDA graph for flashinfer backend  (sgl-project#6805)
* Refactor CustomOp to avoid confusing bugs (sgl-project#5382)
* Tiny log prefill time (sgl-project#6780)
* Tiny fix EPLB assertion about rebalancing period and recorder window size (sgl-project#6813)
* Add simple utility to dump tensors for debugging (sgl-project#6815)
* Fix profiles do not have consistent names (sgl-project#6811)
* Speed up rebalancing when using non-static dispatch algorithms (sgl-project#6812)
* [1/2] Add Kernel support for Cutlass based Fused FP4 MoE (sgl-project#6093)
* [Router] Fix k8s Service Discovery (sgl-project#6766)
* Add CPU optimized kernels for topk and rope fusions  (sgl-project#6456)
* fix new_page_count_next_decode (sgl-project#6671)
* Fix wrong weight reference in dynamic EPLB (sgl-project#6818)
* Minor add metrics to expert location updater (sgl-project#6816)
* [Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (sgl-project#6735)
* [FEAT] Add transformers backend support  (sgl-project#5929)
* [fix] recover auto-dispatch for rmsnorm and rope (sgl-project#6745)
* fix ep_moe_reorder kernel bugs (sgl-project#6858)
* [Refactor] Multimodal data processing for VLM (sgl-project#6659)
* Decoder-only Scoring API (sgl-project#6460)
* feat: add dp-rank to KV events (sgl-project#6852)
* Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (sgl-project#6736)
* Fix one missing arg in DeepEP (sgl-project#6878)
* Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (sgl-project#6861)
* support 1 shot allreduce  in 1-node and 2-node using mscclpp (sgl-project#6277)
* Fix Qwen3MoE missing token padding optimization (sgl-project#6820)
* Tiny update error hints (sgl-project#6846)
* Support layerwise rebalancing experts (sgl-project#6851)
* Tiny allow profiler API to auto create directory (sgl-project#6865)
* Support Blackwell DeepEP docker images (sgl-project#6868)
* [EP] Add cuda kernel for moe_ep_post_reorder (sgl-project#6837)
* [theta]merge 0605
* oai: fix openAI client error with single request via batch api (sgl-project#6170)
* [PD] Fix potential perf spike caused by tracker gc and optimize doc (sgl-project#6764)
* Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (sgl-project#6890)
* [CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (sgl-project#6887)
* bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
* [CPU] enable CI for PRs, add Dockerfile and auto build task (sgl-project#6458)
* AITER backend extension and workload optimizations (sgl-project#6838)
* [theta]merge
* [theta]merge
* [Feature] Support Flashinfer fmha on Blackwell (sgl-project#6930)
* Fix a bug in abort & Improve docstrings for abort (sgl-project#6931)
* Tiny support customize DeepEP max dispatch tokens per rank (sgl-project#6934)
* Sync the changes on cuda graph runners (sgl-project#6932)
* [PD] Optimize transfer queue forward logic for dummy rank (sgl-project#6922)
* [Refactor] image data process in bench_serving (sgl-project#6879)
* [fix] logical_to_all_physical_map index 256 is out of bounds in EP parallel. (sgl-project#6767)
* Add triton fused moe kernel config for E=257 on B200 (sgl-project#6939)
* [sgl-kernel] update deepgemm (sgl-project#6942)
* chore: bump sgl-kernel v0.1.6 (sgl-project#6943)
* Minor compile fused topk (sgl-project#6944)
* [Bugfix] pipeline parallelism and Eagle Qwen2 (sgl-project#6910)
* Tiny re-introduce profile id logging (sgl-project#6912)
* Add triton version as a fused_moe_triton config search key to avoid performace decrease in different Triton version (sgl-project#5955)
* reduce torch.zeros overhead in moe align block size kernel (sgl-project#6369)
* chore: upgrade sgl-kernel v0.1.6 (sgl-project#6945)
* add fbgemm moe grouped gemm kernel benchmark (sgl-project#6924)
* [Docker] Add docker file for SGL Router (sgl-project#6915)
* Disabling mixed chunked prefill when eagle is enabled (sgl-project#6874)
* Add canary for EPLB rebalancing (sgl-project#6895)
* Refactor global_server_args_dict (sgl-project#6866)
* Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)
* Update server timeout time in AMD CI. (sgl-project#6953)
* [misc] add is_cpu() (sgl-project#6950)
* Add H20 fused MoE kernel tuning configs for DeepSeek-R1/V3 (sgl-project#6885)
* Add a CUDA kernel for fusing mapping and weighted sum for MoE. (sgl-project#6916)
* chore: bump sgl-kernel v0.1.6.post1 (sgl-project#6955)
* chore: upgrade sgl-kernel v0.1.6.post1 (sgl-project#6957)
* [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (sgl-project#6853)
* Revert "Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)" (sgl-project#6968)
* [AMD] Add more tests to per-commit-amd (sgl-project#6926)
* chore: bump sgl-kernel v0.1.7 (sgl-project#6963)
* Slightly improve the sampler to skip unnecessary steps (sgl-project#6956)
* rebase h20 fused_moe config (sgl-project#6966)
* Fix CI and triton moe Configs (sgl-project#6974)
* Remove unnecessary kernels of num_token_non_padded (sgl-project#6965)
* Extend cuda graph capture bs for B200 (sgl-project#6937)
* Fuse routed scaling factor in deepseek (sgl-project#6970)
* Sync cuda graph runners (sgl-project#6976)
* Fix draft extend ut stability with flush cache (sgl-project#6979)
* Fix triton sliding window test case (sgl-project#6981)
* Fix expert distribution dumping causes OOM (sgl-project#6967)
* Minor remove one kernel for DeepSeek (sgl-project#6977)
* [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (sgl-project#6929)
* Enable more unit tests for AMD CI. (sgl-project#6983)
* Use torch.compile to fuse flash attention decode metadata preparation (sgl-project#6973)
* Eliminate stream sync to speed up LoRA batch init  (sgl-project#6960)
* support qwen3 emebedding (sgl-project#6990)
* Fix torch profiler bugs for bench_offline_throughput.py (sgl-project#6557)
* chore: upgrade flashinfer v0.2.6.post1 jit (sgl-project#6958)
* cleanup tmp dir (sgl-project#7007)
* chore: update pr test xeon (sgl-project#7008)
* Fix cutlass MLA gets almost zero accuracy (sgl-project#6998)
* Update amd nightly models CI. (sgl-project#6992)
* feat: add direct routing strategy to DP worker (sgl-project#6884)
* Fallback to lower triton version for unfound fused moe configs (sgl-project#7013)
* Fix torchvision version for Blackwell (sgl-project#7015)
* Simplify prepare_extend_after_decode (sgl-project#6987)
* Migrate to assertEqual (sgl-project#6741)
* Fix torch version in blackwell dockerfile (sgl-project#7017)
* chore: update pr test xeon (sgl-project#7018)
* Update default settings for blackwell (sgl-project#7023)
* Support both approximate and exact expert distribution collection (sgl-project#6964)
* Add decode req pool (sgl-project#6980)
* [theta]merge 0610
* [theta]merge 0610
* [CI] Add CI workflow for sgl-router docker build (sgl-project#7027)
* Fix fused_moe triton configs (sgl-project#7029)
* CPU: map changes from developing branch in sgl-kernel (sgl-project#6833)
* chore: bump v0.4.7 (sgl-project#7038)
* Update README.md (sgl-project#7040)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready-to-merge The PR is ready to merge after the CI is green.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants