Skip to content

Reland FusedSampling (#20501)#22166

Closed
Godmook wants to merge 30 commits intosgl-project:mainfrom
Godmook:fused_sampling
Closed

Reland FusedSampling (#20501)#22166
Godmook wants to merge 30 commits intosgl-project:mainfrom
Godmook:fused_sampling

Conversation

@Godmook
Copy link
Copy Markdown
Contributor

@Godmook Godmook commented Apr 5, 2026

Motivation

PR #20501 was merged but failed in CI mainly due to numerical tolerance. After reverting, I refined the test tolerances and removed unstable cases, validated locally against the CI setup, and are relanding once GitHub CI is green. Benchmark notes and rationale are in #20501.

Modifications

  • Two references: reference_fp32 (strict, matches kernel fp32 math) vs reference_native (looser, matches old div_ + softmax in logits dtype).
  • _TOL: Separate fp32_ref / native_ref per dtype (e.g. bf16 native (2e-2, 1e-1)).
  • _check_both_refs: OOP tests assert against both references.
  • Row sums: row_sums ≈ 1 relaxed 1e-41e-3 (bf16 in-place buffer).
  • Less flaky: deterministic test_very_low_temperature; test_mixed_temperatures uses native_ref only for temp >= 0.5; random temps use * 1.5 + 0.5 instead of + 0.1 where noted.
  • In-place bf16: both refs checked with native_ref tol (extra bf16 store).
  • vs Flashinfer: use _TOL[dtype]["native_ref"]; mixed-temp compare only temp >= 0.5.

Accuracy Tests

Tested on A100-40GB

test_pytorch_sampling_backend.py

test_mmlu (__main__.TestPyTorchSamplingBackend.test_mmlu) ... [CI Test Method] TestPyTorchSamplingBackend.test_mmlu
ChatCompletionSampler initialized with self.system_message=None self.temperature=0.1 self.max_tokens=2048 self.reasoning_effort=None self.extra_body=None
  0% 0/64 [00:00<?, ?it/s][2026-04-05 20:02:45] Prefill batch, #new-seq: 1, #new-token: 155, #cached-token: 0, token usage: 0.03, #running-req: 0, #queue-req: 4, cuda graph: True, input throughput (token/s): 9.46
[2026-04-05 20:02:45] Prefill batch, #new-seq: 25, #new-token: 4096, #cached-token: 0, token usage: 0.04, #running-req: 1, #queue-req: 0, cuda graph: True, input throughput (token/s): 8597.39
[2026-04-05 20:02:45] Prefill batch, #new-seq: 5, #new-token: 770, #cached-token: 0, token usage: 0.04, #running-req: 25, #queue-req: 0, cuda graph: True, input throughput (token/s): 81654.47
[2026-04-05 20:02:45] Prefill batch, #new-seq: 2, #new-token: 285, #cached-token: 0, token usage: 0.04, #running-req: 30, #queue-req: 0, cuda graph: True, input throughput (token/s): 6238.24
[2026-04-05 20:02:46] Decode batch, #running-req: 32, #token: 6574, token usage: 0.05, cuda graph: True, gen throughput (token/s): 187.66, #queue-req: 0
[2026-04-05 20:02:46] INFO:     127.0.0.1:40266 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:46] Decode batch, #running-req: 32, #token: 7626, token usage: 0.05, cuda graph: True, gen throughput (token/s): 2046.91, #queue-req: 0
[2026-04-05 20:02:46] INFO:     127.0.0.1:40130 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:46] Prefill batch, #new-seq: 1, #new-token: 316, #cached-token: 0, token usage: 0.06, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 237.38
[2026-04-05 20:02:46] Prefill batch, #new-seq: 1, #new-token: 218, #cached-token: 0, token usage: 0.06, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 7857.04
[2026-04-05 20:02:47] Decode batch, #running-req: 32, #token: 9228, token usage: 0.07, cuda graph: True, gen throughput (token/s): 1865.38, #queue-req: 0
[2026-04-05 20:02:47] Decode batch, #running-req: 32, #token: 10508, token usage: 0.08, cuda graph: True, gen throughput (token/s): 2014.92, #queue-req: 0
[2026-04-05 20:02:48] INFO:     127.0.0.1:40232 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 125, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 154.19
[2026-04-05 20:02:48] INFO:     127.0.0.1:40118 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 307, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 461.77
[2026-04-05 20:02:48] INFO:     127.0.0.1:40104 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 136, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1837.89
[2026-04-05 20:02:48] Decode batch, #running-req: 32, #token: 11137, token usage: 0.08, cuda graph: True, gen throughput (token/s): 1786.95, #queue-req: 0
[2026-04-05 20:02:48] INFO:     127.0.0.1:40180 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] INFO:     127.0.0.1:40064 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] INFO:     127.0.0.1:40174 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] INFO:     127.0.0.1:40254 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 111, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 928.19
[2026-04-05 20:02:48] Prefill batch, #new-seq: 2, #new-token: 303, #cached-token: 0, token usage: 0.08, #running-req: 32, #queue-req: 0, cuda graph: True, input throughput (token/s): 3748.56
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 418, #cached-token: 0, token usage: 0.08, #running-req: 34, #queue-req: 0, cuda graph: True, input throughput (token/s): 12356.25
[2026-04-05 20:02:48] INFO:     127.0.0.1:40202 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] INFO:     127.0.0.1:40220 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 113, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 6136.42
[2026-04-05 20:02:48] Prefill batch, #new-seq: 1, #new-token: 148, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 2074.38
[2026-04-05 20:02:49] INFO:     127.0.0.1:40262 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 119, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 905.21
[2026-04-05 20:02:49] INFO:     127.0.0.1:39992 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] INFO:     127.0.0.1:40044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 177, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 562.04
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 352, #cached-token: 0, token usage: 0.08, #running-req: 32, #queue-req: 0, cuda graph: True, input throughput (token/s): 9543.17
[2026-04-05 20:02:49] INFO:     127.0.0.1:40086 - "POST /v1/chat/completions HTTP/1.1" 200 OK
  2% 1/64 [00:04<04:51,  4.62s/it][2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 172, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 4042.20
[2026-04-05 20:02:49] INFO:     127.0.0.1:40240 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 113, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1485.43
[2026-04-05 20:02:49] Decode batch, #running-req: 32, #token: 10670, token usage: 0.08, cuda graph: True, gen throughput (token/s): 1433.85, #queue-req: 0
[2026-04-05 20:02:49] INFO:     127.0.0.1:40144 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 122, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 974.94
[2026-04-05 20:02:49] INFO:     127.0.0.1:40012 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 177, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1705.23
[2026-04-05 20:02:49] INFO:     127.0.0.1:40010 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 138, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 2515.99
[2026-04-05 20:02:49] INFO:     127.0.0.1:40274 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] Prefill batch, #new-seq: 1, #new-token: 231, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1326.07
[2026-04-05 20:02:49] INFO:     127.0.0.1:40080 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:49] INFO:     127.0.0.1:40018 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:50] Prefill batch, #new-seq: 1, #new-token: 293, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1664.65
[2026-04-05 20:02:50] Prefill batch, #new-seq: 1, #new-token: 186, #cached-token: 0, token usage: 0.07, #running-req: 32, #queue-req: 0, cuda graph: True, input throughput (token/s): 25804.03
[2026-04-05 20:02:50] INFO:     127.0.0.1:40028 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:50] Prefill batch, #new-seq: 1, #new-token: 120, #cached-token: 0, token usage: 0.07, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 952.24
[2026-04-05 20:02:50] Decode batch, #running-req: 32, #token: 10430, token usage: 0.07, cuda graph: True, gen throughput (token/s): 1586.51, #queue-req: 0
[2026-04-05 20:02:51] Decode batch, #running-req: 32, #token: 11710, token usage: 0.08, cuda graph: True, gen throughput (token/s): 1998.76, #queue-req: 0
[2026-04-05 20:02:51] INFO:     127.0.0.1:40160 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 110, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 135.68
[2026-04-05 20:02:51] INFO:     127.0.0.1:40114 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 125, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 941.67
[2026-04-05 20:02:51] INFO:     127.0.0.1:40266 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] INFO:     127.0.0.1:40202 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 106, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1831.74
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 169, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1911.14
[2026-04-05 20:02:51] INFO:     127.0.0.1:39998 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 280, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1186.15
[2026-04-05 20:02:51] INFO:     127.0.0.1:40236 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 147, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1674.56
[2026-04-05 20:02:51] Decode batch, #running-req: 32, #token: 11179, token usage: 0.08, cuda graph: True, gen throughput (token/s): 1635.11, #queue-req: 0
[2026-04-05 20:02:51] INFO:     127.0.0.1:40052 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:51] Prefill batch, #new-seq: 1, #new-token: 158, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 559.53
[2026-04-05 20:02:51] INFO:     127.0.0.1:40086 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] Prefill batch, #new-seq: 1, #new-token: 290, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 1670.03
[2026-04-05 20:02:52] INFO:     127.0.0.1:40180 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] Prefill batch, #new-seq: 1, #new-token: 160, #cached-token: 0, token usage: 0.08, #running-req: 31, #queue-req: 0, cuda graph: True, input throughput (token/s): 4099.38
[2026-04-05 20:02:52] INFO:     127.0.0.1:40130 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] INFO:     127.0.0.1:40232 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] INFO:     127.0.0.1:40254 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] Decode batch, #running-req: 29, #token: 10096, token usage: 0.07, cuda graph: True, gen throughput (token/s): 1715.33, #queue-req: 0
[2026-04-05 20:02:52] INFO:     127.0.0.1:40012 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] INFO:     127.0.0.1:40090 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] INFO:     127.0.0.1:40174 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:52] INFO:     127.0.0.1:39992 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] Decode batch, #running-req: 25, #token: 9399, token usage: 0.07, cuda graph: True, gen throughput (token/s): 1674.80, #queue-req: 0
[2026-04-05 20:02:53] INFO:     127.0.0.1:40018 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40010 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40220 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40114 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] Decode batch, #running-req: 23, #token: 8925, token usage: 0.06, cuda graph: True, gen throughput (token/s): 1539.05, #queue-req: 0
[2026-04-05 20:02:53] INFO:     127.0.0.1:40104 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40210 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40144 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:53] INFO:     127.0.0.1:40028 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40262 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40118 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] Decode batch, #running-req: 14, #token: 5976, token usage: 0.04, cuda graph: True, gen throughput (token/s): 1160.40, #queue-req: 0
[2026-04-05 20:02:54] INFO:     127.0.0.1:40064 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40266 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40080 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40240 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40160 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:40052 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] Decode batch, #running-req: 9, #token: 3993, token usage: 0.03, cuda graph: True, gen throughput (token/s): 873.69, #queue-req: 0
[2026-04-05 20:02:54] INFO:     127.0.0.1:40274 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:54] INFO:     127.0.0.1:39998 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:55] Decode batch, #running-req: 6, #token: 3197, token usage: 0.02, cuda graph: True, gen throughput (token/s): 462.08, #queue-req: 0
[2026-04-05 20:02:56] Decode batch, #running-req: 6, #token: 3437, token usage: 0.02, cuda graph: True, gen throughput (token/s): 441.84, #queue-req: 0
[2026-04-05 20:02:56] INFO:     127.0.0.1:40202 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:56] INFO:     127.0.0.1:40236 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:56] INFO:     127.0.0.1:40086 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:56] Decode batch, #running-req: 3, #token: 2087, token usage: 0.01, cuda graph: True, gen throughput (token/s): 384.49, #queue-req: 0
[2026-04-05 20:02:57] Decode batch, #running-req: 3, #token: 2207, token usage: 0.02, cuda graph: True, gen throughput (token/s): 225.05, #queue-req: 0
[2026-04-05 20:02:57] Decode batch, #running-req: 3, #token: 2327, token usage: 0.02, cuda graph: True, gen throughput (token/s): 224.85, #queue-req: 0
[2026-04-05 20:02:57] INFO:     127.0.0.1:40180 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:02:58] Decode batch, #running-req: 2, #token: 1860, token usage: 0.01, cuda graph: True, gen throughput (token/s): 175.33, #queue-req: 0
[2026-04-05 20:02:58] Decode batch, #running-req: 2, #token: 1940, token usage: 0.01, cuda graph: True, gen throughput (token/s): 151.35, #queue-req: 0
[2026-04-05 20:02:59] Decode batch, #running-req: 2, #token: 2020, token usage: 0.01, cuda graph: True, gen throughput (token/s): 151.19, #queue-req: 0
[2026-04-05 20:02:59] Decode batch, #running-req: 2, #token: 2100, token usage: 0.02, cuda graph: True, gen throughput (token/s): 151.01, #queue-req: 0
[2026-04-05 20:03:00] Decode batch, #running-req: 2, #token: 2180, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.90, #queue-req: 0
[2026-04-05 20:03:00] Decode batch, #running-req: 2, #token: 2260, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.75, #queue-req: 0
[2026-04-05 20:03:01] Decode batch, #running-req: 2, #token: 2340, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.66, #queue-req: 0
[2026-04-05 20:03:01] Decode batch, #running-req: 2, #token: 2420, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.63, #queue-req: 0
[2026-04-05 20:03:02] Decode batch, #running-req: 2, #token: 2500, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.57, #queue-req: 0
[2026-04-05 20:03:02] Decode batch, #running-req: 2, #token: 2580, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.42, #queue-req: 0
[2026-04-05 20:03:03] Decode batch, #running-req: 2, #token: 2660, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.33, #queue-req: 0
[2026-04-05 20:03:03] Decode batch, #running-req: 2, #token: 2740, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.22, #queue-req: 0
[2026-04-05 20:03:04] Decode batch, #running-req: 2, #token: 2820, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.14, #queue-req: 0
[2026-04-05 20:03:05] Decode batch, #running-req: 2, #token: 2900, token usage: 0.02, cuda graph: True, gen throughput (token/s): 150.07, #queue-req: 0
[2026-04-05 20:03:05] Decode batch, #running-req: 2, #token: 2980, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.95, #queue-req: 0
[2026-04-05 20:03:06] Decode batch, #running-req: 2, #token: 3060, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.87, #queue-req: 0
[2026-04-05 20:03:06] Decode batch, #running-req: 2, #token: 3140, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.82, #queue-req: 0
[2026-04-05 20:03:07] Decode batch, #running-req: 2, #token: 3220, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.70, #queue-req: 0
[2026-04-05 20:03:07] Decode batch, #running-req: 2, #token: 3300, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.69, #queue-req: 0
[2026-04-05 20:03:08] Decode batch, #running-req: 2, #token: 3380, token usage: 0.02, cuda graph: True, gen throughput (token/s): 149.32, #queue-req: 0
[2026-04-05 20:03:08] Decode batch, #running-req: 2, #token: 3460, token usage: 0.02, cuda graph: True, gen throughput (token/s): 148.80, #queue-req: 0
[2026-04-05 20:03:09] Decode batch, #running-req: 2, #token: 3540, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.82, #queue-req: 0
[2026-04-05 20:03:09] Decode batch, #running-req: 2, #token: 3620, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.75, #queue-req: 0
[2026-04-05 20:03:10] Decode batch, #running-req: 2, #token: 3700, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.79, #queue-req: 0
[2026-04-05 20:03:10] Decode batch, #running-req: 2, #token: 3780, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.69, #queue-req: 0
[2026-04-05 20:03:11] Decode batch, #running-req: 2, #token: 3860, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.67, #queue-req: 0
[2026-04-05 20:03:12] Decode batch, #running-req: 2, #token: 3940, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.66, #queue-req: 0
[2026-04-05 20:03:12] Decode batch, #running-req: 2, #token: 4020, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.65, #queue-req: 0
[2026-04-05 20:03:13] Decode batch, #running-req: 2, #token: 4100, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.59, #queue-req: 0
[2026-04-05 20:03:13] Decode batch, #running-req: 2, #token: 4180, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.53, #queue-req: 0
[2026-04-05 20:03:14] Decode batch, #running-req: 2, #token: 4260, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.43, #queue-req: 0
[2026-04-05 20:03:14] Decode batch, #running-req: 2, #token: 4340, token usage: 0.03, cuda graph: True, gen throughput (token/s): 148.42, #queue-req: 0
[2026-04-05 20:03:14] INFO:     127.0.0.1:39988 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2026-04-05 20:03:14] INFO:     127.0.0.1:40194 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100% 64/64 [00:30<00:00,  2.13it/s]
Total latency: 30.093 s
Score: 0.766
Output throughput: 644.566 token/s
[METRIC] mmlu_score=0.765625 labels={"model": "meta-llama/Llama-3.1-8B-Instruct", "eval": "mmlu"}
[METRIC] mmlu_latency=30.09313351999981 labels={"model": "meta-llama/Llama-3.1-8B-Instruct", "eval": "mmlu"}
Writing report to /tmp/mmlu_meta-llama_Llama-3.1-8B-Instruct.html
{'other': np.float64(0.8125), 'other:std': np.float64(0.3903123748998999), 'score:std': np.float64(0.4236075534914362), 'stem': np.float64(0.8181818181818182), 'stem:std': np.float64(0.385694607919935), 'humanities': np.float64(0.7391304347826086), 'humanities:std': np.float64(0.43910891036356864), 'social_sciences': np.float64(0.7142857142857143), 'social_sciences:std': np.float64(0.45175395145262565), 'score': np.float64(0.765625), 'latency': 30.09313351999981, 'output_throughput': 644.565644422134}
Writing results to /tmp/mmlu_meta-llama_Llama-3.1-8B-Instruct.json
ok

----------------------------------------------------------------------
Ran 2 tests in 129.011s

OK

test_fused_temperature_softmax.py

test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_argmax_preserved PASSED [  4%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_basic PASSED [  8%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_batch_sizes PASSED [ 12%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_empty_batch PASSED [ 16%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_fp16_input PASSED [ 20%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_fp32_input PASSED [ 25%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_basic PASSED [ 29%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_bf16 PASSED [ 33%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_large_vocab PASSED [ 37%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_known_softmax_values PASSED [ 41%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_known_softmax_with_temperature PASSED [ 45%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_known_uniform_logits PASSED [ 50%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_large_logits_inplace_no_nan PASSED [ 54%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_large_logits_no_nan PASSED [ 58%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_large_vocab PASSED [ 62%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_mixed_temperatures PASSED [ 66%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_temperature_one PASSED [ 70%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_very_high_temperature PASSED [ 75%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_very_low_temperature PASSED [ 79%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_vs_flashinfer_basic PASSED [ 83%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_vs_flashinfer_batch_sizes PASSED [ 87%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_vs_flashinfer_large_vocab PASSED [ 91%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_vs_flashinfer_mixed_temperatures PASSED [ 95%]
test/registered/sampling/test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_vs_flashinfer_scalar_temperature PASSED [100%]

=============================== warnings summary ===============================
../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1474
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1474: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 24 passed, 1 warning in 14.53s ========================

Speed Tests and Profiling

I tried to return to online softmax with more better codes. I found some errors previously when I applied online softmax, and I think it works well when I'm doing CI tests.

A. Accuracy (fused vs PyTorch fp32 reference)

label max_abs max_rel row_sum_err fi_max_abs status
bf16 small vocab 1.28e-09 8.77e-07 1.19e-07 9.31e-10 PASS
bf16 large vocab (multi-pass) 3.73e-09 1.79e-06 1.19e-07 4.66e-09 PASS
fp32 large vocab (multi-pass) 4.66e-10 1.48e-06 1.19e-07 4.66e-10 PASS
fp16 large vocab (multi-pass) 2.91e-10 1.35e-06 1.19e-07 2.33e-10 PASS
bf16 very low temp 7.51e-06 1.38e-05 7.45e-06 7.51e-06 PASS
bf16 very high temp 1.82e-12 2.38e-07 1.19e-07 1.82e-12 PASS
bf16 batch=64 large vocab 4.66e-09 2.44e-06 2.38e-07 4.66e-09 PASS
inplace bf16 large vocab 4.70e-06 PASS

Overall accuracy: ALL PASS


B. Performance (bf16, μs / iter, CUDA events)

bs vocab PyTorch Triton oop Triton in-place FlashInfer tri/py ip/py tri/fi
1 32K 35.4 64.2 52.6 46.5 0.55x 0.67x 0.72x
1 128K 76.6 79.7 81.4 46.1 0.96x 0.94x 0.58x
32 32K 42.9 64.0 52.0 46.4 0.67x 0.83x 0.73x
32 128K 106.7 79.7 68.9 81.6 1.34x 1.55x 1.02x
128 32K 68.5 63.7 51.8 71.3 1.08x 1.32x 1.12x
128 128K 285.7 170.8 152.0 309.3 1.67x 1.88x 1.81x
512 32K 259.3 173.9 113.0 259.8 1.49x 2.29x 1.49x
512 128K 1080.5 605.6 523.5 1185.3 1.78x 2.06x 1.96x
  • tri/py = PyTorch / Triton oop (higher = Triton faster)
  • ip/py = PyTorch / Triton in-place
  • tri/fi = FlashInfer / Triton oop

C. Kernel variant (vocab → path)

vocab next_pow2 path
32000 32768 single-pass
128256 131072 multi-pass (online 2-pass)
152064 262144 multi-pass (online 2-pass)

So Set Threshold to 32 for better speed.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 5, 2026

@DarkSharpness I relanded. You can see logs at PR Description. If you think there need be more CI tests locally, please let me know Thanks! :)

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces fused Triton kernels for temperature scaling and softmax to optimize the sampling pipeline, featuring both single-pass and multi-pass implementations with an integration threshold in the sampler. Review feedback highlights critical numerical stability concerns in the kernels where scaling logits before computing the maximum could lead to overflows and NaNs at low temperatures; it is recommended to scale the differences after the maximum reduction instead. Additionally, the feedback suggests handling potential division-by-zero for zero temperatures and correcting the multi-pass kernel documentation to accurately reflect its three-pass implementation.

Comment on lines +45 to +54
x = tl.load(
logits_ptr + row_idx * logits_stride + offsets,
mask=mask,
other=float("-inf"),
)
x = (x / temp).to(tl.float32)

x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation scales logits by temperature before computing the maximum. This can lead to numerical instability: if the temperature is very small, x / temp can overflow to infinity. If multiple elements overflow, x - x_max becomes NaN (inf - inf).

A more robust approach is to compute the maximum on the raw logits first, then scale the difference. Additionally, consider handling the case where temp might be zero to avoid division by zero errors in mixed batches.

Suggested change
x = tl.load(
logits_ptr + row_idx * logits_stride + offsets,
mask=mask,
other=float("-inf"),
)
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
x = tl.load(
logits_ptr + row_idx * logits_stride + offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
x_max = tl.max(x, axis=0)
# Use a small epsilon for temp to avoid division by zero if temp=0 reaches here
safe_temp = tl.maximum(temp, 1e-6)
exp_x = tl.exp((x - x_max) / safe_temp)
prob = exp_x / tl.sum(exp_x, axis=0)

Comment on lines +74 to +79
x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
x = (x / temp).to(tl.float32)

x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This in-place kernel suffers from the same numerical stability and potential division-by-zero issues as the out-of-place variant. Scaling before the maximum reduction can cause overflows and NaNs at low temperatures.

Suggested change
x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
x = tl.load(row_start + offsets, mask=mask, other=float("-inf")).to(tl.float32)
x_max = tl.max(x, axis=0)
safe_temp = tl.maximum(temp, 1e-6)
exp_x = tl.exp((x - x_max) / safe_temp)
prob = exp_x / tl.sum(exp_x, axis=0)

Comment on lines +124 to +130
global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
for start in range(0, vocab_size, BLOCK_SIZE):
offsets = start + tl.arange(0, BLOCK_SIZE)
mask = offsets < vocab_size
x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
x = (x / temp).to(tl.float32)
global_max = tl.maximum(global_max, tl.max(x, axis=0))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the single-pass kernel, computing the maximum on temperature-scaled logits (x / temp) can lead to numerical instability and NaNs if the temperature is very low. It is safer to find the global maximum of the raw logits first and then apply the temperature scaling to the differences in subsequent passes.

@Godmook Godmook marked this pull request as draft April 5, 2026 23:07
@Godmook Godmook marked this pull request as ready for review April 6, 2026 01:27
@DarkSharpness
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 6, 2026

Close because #22178 shows better performance

@Godmook Godmook closed this Apr 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants