… for sampling
- Add fused_temperature_softmax Triton kernel (2-pass online softmax)
- Replace div_ + softmax with fused kernel in sampler standard path
- Add correctness tests and benchmark script
- Reduces kernel launches from 2 to 1 and memory passes from 6 to 3
Made-with: Cursor
@DarkSharpness I relanded. You can see the logs in the PR description. If you think more CI tests should be run locally, please let me know. Thanks! :)
Code Review
This pull request introduces fused Triton kernels for temperature scaling and softmax to optimize the sampling pipeline, featuring both single-pass and multi-pass implementations with an integration threshold in the sampler. Review feedback highlights critical numerical stability concerns in the kernels where scaling logits before computing the maximum could lead to overflows and NaNs at low temperatures; it is recommended to scale the differences after the maximum reduction instead. Additionally, the feedback suggests handling potential division-by-zero for zero temperatures and correcting the multi-pass kernel documentation to accurately reflect its three-pass implementation.
```python
x = tl.load(
    logits_ptr + row_idx * logits_stride + offsets,
    mask=mask,
    other=float("-inf"),
)
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
```
The current implementation scales logits by temperature before computing the maximum. This can lead to numerical instability: if the temperature is very small, x / temp can overflow to infinity. If multiple elements overflow, x - x_max becomes NaN (inf - inf).
A more robust approach is to compute the maximum on the raw logits first, then scale the difference. Additionally, consider handling the case where temp might be zero to avoid division by zero errors in mixed batches.
Suggested change:
```diff
-x = tl.load(
-    logits_ptr + row_idx * logits_stride + offsets,
-    mask=mask,
-    other=float("-inf"),
-)
-x = (x / temp).to(tl.float32)
-x_max = tl.max(x, axis=0)
-exp_x = tl.exp(x - x_max)
-prob = exp_x / tl.sum(exp_x, axis=0)
+x = tl.load(
+    logits_ptr + row_idx * logits_stride + offsets,
+    mask=mask,
+    other=float("-inf"),
+).to(tl.float32)
+x_max = tl.max(x, axis=0)
+# Use a small epsilon for temp to avoid division by zero if temp=0 reaches here
+safe_temp = tl.maximum(temp, 1e-6)
+exp_x = tl.exp((x - x_max) / safe_temp)
+prob = exp_x / tl.sum(exp_x, axis=0)
```
```python
x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
```
This in-place kernel suffers from the same numerical stability and potential division-by-zero issues as the out-of-place variant. Scaling before the maximum reduction can cause overflows and NaNs at low temperatures.
Suggested change:
```diff
-x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
-x = (x / temp).to(tl.float32)
-x_max = tl.max(x, axis=0)
-exp_x = tl.exp(x - x_max)
-prob = exp_x / tl.sum(exp_x, axis=0)
+x = tl.load(row_start + offsets, mask=mask, other=float("-inf")).to(tl.float32)
+x_max = tl.max(x, axis=0)
+safe_temp = tl.maximum(temp, 1e-6)
+exp_x = tl.exp((x - x_max) / safe_temp)
+prob = exp_x / tl.sum(exp_x, axis=0)
```
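The division-by-zero concern matters for mixed batches where some rows sample greedily (temperature 0). A minimal NumPy sketch of one host-side way to handle this (the function and names are hypothetical, not from this PR): route zero-temperature rows to argmax and apply the stable scaled-difference softmax to the rest.

```python
import numpy as np

def mixed_batch_probs(logits, temps, eps=1e-6):
    """Per-row temperature softmax; temp == 0 rows become one-hot argmax."""
    out = np.zeros(logits.shape, dtype=np.float32)
    for i, t in enumerate(temps):
        row = logits[i].astype(np.float32)
        if t == 0.0:
            # Greedy row: no division at all, so no epsilon needed.
            out[i, int(row.argmax())] = 1.0
        else:
            # Max over raw logits first, then scale the differences.
            z = np.exp((row - row.max()) / max(float(t), eps))
            out[i] = z / z.sum()
    return out
```

In practice a sampler would typically filter greedy rows out before the softmax kernel launch; the epsilon clamp is a second line of defense for any zero that slips through.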
```python
global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
for start in range(0, vocab_size, BLOCK_SIZE):
    offsets = start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < vocab_size
    x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
    x = (x / temp).to(tl.float32)
    global_max = tl.maximum(global_max, tl.max(x, axis=0))
```
Similar to the single-pass kernel, computing the maximum on temperature-scaled logits (x / temp) can lead to numerical instability and NaNs if the temperature is very low. It is safer to find the global maximum of the raw logits first and then apply the temperature scaling to the differences in subsequent passes.
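The reviewer's recommended structure can be illustrated with a NumPy sketch of a blocked three-pass softmax (illustrative only, not the PR's kernel): pass 1 reduces the max over raw logits block by block, and only passes 2 and 3 apply the temperature to the differences.

```python
import numpy as np

def three_pass_softmax(logits, temp, block=4, eps=1e-6):
    n = logits.size
    safe_temp = max(float(temp), eps)
    # Pass 1: global max over RAW logits -- stable at any temperature.
    gmax = -np.inf
    for s in range(0, n, block):
        gmax = max(gmax, float(logits[s:s + block].max()))
    # Pass 2: accumulate the normalizer from temperature-scaled differences.
    denom = 0.0
    for s in range(0, n, block):
        denom += float(np.exp((logits[s:s + block] - gmax) / safe_temp).sum())
    # Pass 3: write normalized probabilities.
    out = np.empty(n, dtype=np.float32)
    for s in range(0, n, block):
        out[s:s + block] = np.exp((logits[s:s + block] - gmax) / safe_temp) / denom
    return out
```

Because `gmax` comes from unscaled logits, no pass ever divides a large logit by a tiny temperature before the subtraction, so no intermediate can overflow.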
/tag-and-rerun-ci
Closing because #22178 shows better performance.
Motivation
PR #20501 was merged but failed in CI, mainly due to numerical tolerance. After reverting, I refined the test tolerances and removed unstable cases, validated locally against the CI setup, and am relanding once GitHub CI is green. Benchmark notes and rationale are in #20501.
Modifications
- `reference_fp32` (strict, matches kernel fp32 math) vs `reference_native` (looser, matches old `div_` + softmax in logits dtype).
- `_TOL`: separate `fp32_ref` / `native_ref` per dtype (e.g. bf16 native `(2e-2, 1e-1)`).
- `_check_both_refs`: OOP tests assert against both references.
- `row_sums ≈ 1` relaxed `1e-4` → `1e-3` (bf16 in-place buffer).
- `test_very_low_temperature`; `test_mixed_temperatures` uses native_ref only for `temp >= 0.5`; random temps use `* 1.5 + 0.5` instead of `+ 0.1` where noted.
- `native_ref` tol (extra bf16 store).
- `_TOL[dtype]["native_ref"]`; mixed-temp compare only `temp >= 0.5`.

Accuracy Tests
Tested on A100-40GB
`test_pytorch_sampling_backend.py`
`test_fused_temperature_softmax.py`

Speed Tests and Profiling
I returned to the online softmax with improved code. I previously ran into some errors when applying online softmax, but it now works well in my CI testing.
A. Accuracy (fused vs PyTorch fp32 reference)
Overall accuracy: ALL PASS
B. Performance (bf16, μs / iter, CUDA events)
C. Kernel variant (vocab → path)
So the threshold is set to 32 for better speed.
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci