diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py
index 2eb2de3875..cc2406e43f 100644
--- a/benchmarks/bench_sampling.py
+++ b/benchmarks/bench_sampling.py
@@ -220,6 +220,86 @@ def main():
                     f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                 )
 
+    print("---")
+    print("top-p renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for p in [0.1, 0.5, 0.9]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k renorm probs")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for k in [10, 100, 1000, 5000]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    probs = torch.softmax(logits, dim=-1)
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = probs.numel() * probs.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
+    print("---")
+    print("top-k mask logits")
+    for vocab_size in [128512]:
+        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
+            torch.manual_seed(42)
+            for distrib in [
+                normal_distribution(1),
+                normal_distribution(5),
+                gumbel_distribution(0.1),
+                gumbel_distribution(1),
+            ]:
+                for k in [10, 100, 1000, 5000]:
+                    logits = distrib((batch_size, vocab_size), device="cuda")
+                    measurements = bench_gpu_time(
+                        lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
+                        dry_run_time_ms=100,
+                        repeat_time_ms=1000,
+                    )
+                    ms = np.median(measurements)
+
+                    io = logits.numel() * logits.element_size() * 2
+                    bandwidth = io * 1e-6 / ms
+                    print(
+                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
+                    )
+
 
 if __name__ == "__main__":
     main()
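Note on the effective-bandwidth numbers these loops print: each kernel reads its input tensor once and writes an output of the same size, so traffic is modeled as two passes over the data, and bytes / milliseconds * 1e-6 gives GB/s. A minimal sketch of that arithmetic (the helper name and sample numbers are illustrative, not part of the benchmark):

    def effective_bandwidth_gb_s(numel: int, element_size: int, time_ms: float) -> float:
        io_bytes = numel * element_size * 2  # one read + one write
        return io_bytes * 1e-6 / time_ms     # bytes/ms * 1e-6 == GB/s

    # e.g. a (512, 128512) float32 tensor measured at 0.5 ms:
    # effective_bandwidth_gb_s(512 * 128512, 4, 0.5) -> ~1053 GB/s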
diff --git a/benchmarks/bench_softmax.py b/benchmarks/bench_softmax.py
new file mode 100755
index 0000000000..6da8dc9fcb
--- /dev/null
+++ b/benchmarks/bench_softmax.py
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3
+"""
+Benchmark script comparing torch.softmax vs flashinfer.softmax performance.
+Creates a heatmap showing speedup across different batch sizes and hidden dimensions.
+"""
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from typing import List, Tuple
+import flashinfer
+from flashinfer.testing.utils import bench_gpu_time
+
+
+@torch.inference_mode()
+def benchmark_torch_softmax(logits: torch.Tensor) -> float:
+    """Benchmark torch's native softmax."""
+    measurements = bench_gpu_time(
+        lambda: torch.softmax(logits, dim=-1),
+        dry_run_time_ms=100,
+        repeat_time_ms=1000,
+    )
+    return np.median(measurements)
+
+
+@torch.inference_mode()
+def benchmark_flashinfer_softmax(logits: torch.Tensor) -> float:
+    """Benchmark flashinfer's softmax."""
+    measurements = bench_gpu_time(
+        lambda: flashinfer.sampling.softmax(logits, temperature=None, enable_pdl=False),
+        dry_run_time_ms=100,
+        repeat_time_ms=1000,
+    )
+    return np.median(measurements)
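Aside: before timing, it can be worth checking that the two code paths agree numerically; a small check along these lines (not part of the script) exercises the same call signature used above:

    import torch
    import flashinfer

    def check_softmax_close(batch_size: int = 8, hidden_size: int = 32000) -> None:
        logits = torch.randn(batch_size, hidden_size, device="cuda", dtype=torch.float32)
        ref = torch.softmax(logits, dim=-1)
        out = flashinfer.sampling.softmax(logits, temperature=None, enable_pdl=False)
        # both sides accumulate in fp32, so a fairly tight tolerance is reasonable
        torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)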
+
+
+def run_benchmark(
+    batch_sizes: List[int], hidden_sizes: List[int]
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Run benchmarks for all combinations of batch_size and hidden_size.
+
+    Returns:
+        torch_times: 2D array of torch softmax times (ms)
+        flashinfer_times: 2D array of flashinfer softmax times (ms)
+        speedups: 2D array of speedup ratios (torch_time / flashinfer_time)
+    """
+    n_batch = len(batch_sizes)
+    n_hidden = len(hidden_sizes)
+
+    torch_times = np.zeros((n_batch, n_hidden))
+    flashinfer_times = np.zeros((n_batch, n_hidden))
+    speedups = np.zeros((n_batch, n_hidden))
+
+    print("Running benchmarks...")
+    print("=" * 100)
+    print(
+        f"{'Batch Size':<12} {'Hidden Size':<12} {'Torch (ms)':<15} "
+        f"{'FlashInfer (ms)':<18} {'Speedup':<10} {'Bandwidth (GB/s)':<18}"
+    )
+    print("=" * 100)
+
+    for i, batch_size in enumerate(batch_sizes):
+        for j, hidden_size in enumerate(hidden_sizes):
+            # Generate random logits
+            torch.manual_seed(42)
+            logits = torch.randn(
+                batch_size, hidden_size, device="cuda", dtype=torch.float32
+            )
+
+            # Benchmark torch softmax
+            torch_time_ms = benchmark_torch_softmax(logits)
+            torch_times[i, j] = torch_time_ms
+
+            # Benchmark flashinfer softmax
+            flashinfer_time_ms = benchmark_flashinfer_softmax(logits)
+            flashinfer_times[i, j] = flashinfer_time_ms
+
+            # Calculate speedup
+            speedup = torch_time_ms / flashinfer_time_ms
+            speedups[i, j] = speedup
+
+            # Calculate effective bandwidth (read + write)
+            io_bytes = logits.numel() * logits.element_size() * 2
+            bandwidth_gb_s = io_bytes * 1e-6 / flashinfer_time_ms
+
+            print(
+                f"{batch_size:<12} {hidden_size:<12} {torch_time_ms:<15.4f} "
+                f"{flashinfer_time_ms:<18.4f} {speedup:<10.2f}x {bandwidth_gb_s:<18.2f}"
+            )
+
+    print("=" * 100)
+    return torch_times, flashinfer_times, speedups
+
+
+def plot_heatmap(
+    speedups: np.ndarray,
+    batch_sizes: List[int],
+    hidden_sizes: List[int],
+    save_path: str = "softmax_speedup_heatmap.png",
+):
+    """Create and save a heatmap of speedup values."""
+    # Create figure
+    fig, ax = plt.subplots(figsize=(12, 8))
+
+    # Create heatmap
+    sns.heatmap(
+        speedups,
+        annot=True,
+        fmt=".2f",
+        cmap="RdYlGn",
+        center=1.0,
+        cbar_kws={"label": "Speedup (x)"},
+        xticklabels=[f"{h // 1000}K" for h in hidden_sizes],
+        yticklabels=batch_sizes,
+        ax=ax,
+        vmin=0.5,  # Adjust color scale
+        vmax=max(3.0, speedups.max()),  # Dynamic upper bound
+    )
+
+    ax.set_xlabel("Hidden Size", fontsize=12, fontweight="bold")
+    ax.set_ylabel("Batch Size", fontsize=12, fontweight="bold")
+    ax.set_title(
+        "FlashInfer Softmax Speedup vs PyTorch (Higher is Better)",
+        fontsize=14,
+        fontweight="bold",
+        pad=20,
+    )
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches="tight")
+    print(f"\nHeatmap saved to: {save_path}")
+
+    # Also create a performance comparison plot
+    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
+
+    # Plot 2: Speedup trends across batch sizes
+    for j, hidden_size in enumerate(hidden_sizes):
+        ax2.plot(
+            batch_sizes,
+            speedups[:, j],
+            marker="o",
+            label=f"Hidden={hidden_size // 1000}K",
+            linewidth=2,
+        )
+
+    ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold")
+    ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
+    ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold")
+    ax2.set_xscale("log", base=2)
+    ax2.grid(True, alpha=0.3)
+    ax2.legend(fontsize=9)
+    ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+
+    # Plot 1: Speedup trends across hidden sizes
+    for i, batch_size in enumerate(batch_sizes[::2]):  # Sample every other batch size
+        idx = i * 2
+        ax1.plot(
+            [h // 1000 for h in hidden_sizes],
+            speedups[idx, :],
+            marker="s",
+            label=f"Batch={batch_size}",
+            linewidth=2,
+        )
+
+    ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold")
+    ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
+    ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
+    ax1.grid(True, alpha=0.3)
+    ax1.legend(fontsize=9)
+    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
+
+    plt.tight_layout()
+    comparison_path = save_path.replace(".png", "_trends.png")
+    plt.savefig(comparison_path, dpi=300, bbox_inches="tight")
+    print(f"Trend plots saved to: {comparison_path}")
+
+
+def main():
+    """Main benchmark execution."""
+    # Configuration
+    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
+    hidden_sizes = [32000, 64000, 128000, 256000]
+
+    print("=" * 100)
+    print("FlashInfer vs PyTorch Softmax Benchmark")
+    print("=" * 100)
+    print(f"Batch sizes: {batch_sizes}")
+    print(f"Hidden sizes: {hidden_sizes}")
+    print(f"Device: {torch.cuda.get_device_name()}")
+    print("=" * 100)
+    print()
+
+    # Run benchmarks
+    _, _, speedups = run_benchmark(batch_sizes, hidden_sizes)
+
+    # Print summary statistics
+    print("\nSummary Statistics:")
+    print("=" * 100)
+    print(f"Average speedup: {np.mean(speedups):.2f}x")
+    print(f"Median speedup: {np.median(speedups):.2f}x")
+    print(f"Min speedup: {np.min(speedups):.2f}x")
+    print(f"Max speedup: {np.max(speedups):.2f}x")
+    print("=" * 100)
+
+    # Generate heatmap
+    plot_heatmap(speedups, batch_sizes, hidden_sizes)
+
+    print("\nBenchmark complete!")
+
+
+if __name__ == "__main__":
+    main()
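For a quick smoke test on a smaller grid, the two entry points can also be driven directly (assuming the file is importable, e.g. when run from the benchmarks/ directory; sizes here are arbitrary examples):

    from bench_softmax import run_benchmark, plot_heatmap

    _, _, speedups = run_benchmark(batch_sizes=[1, 32], hidden_sizes=[32000, 128000])
    plot_heatmap(speedups, [1, 32], [32000, 128000], save_path="softmax_speedup_small.png")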
diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 6b134630cf..f3b188abec 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -333,6 +333,7 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te
 
   float running_max = -cuda::std::numeric_limits<float>::infinity();
   float running_denominator = 0.0f;
+  float threadlocal_running_denominator = 0.0f;
 
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.wait;");
@@ -368,39 +369,32 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te
     }
     __syncthreads();
     block_max = temp_storage.shared_state.max_val;
-
     // if block_max is -inf, then this block contains all -inf values, so we can skip updating
     if (!isinf(block_max)) {
-      float thread_sum = 0.0f;
+      float threadlocal_sum = 0.0f;
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        thread_sum += __expf(logits_vec[j] - block_max);
-      }
-
-      float block_sum =
-          cub::BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce).Sum(thread_sum);
-      __syncthreads();
-
-      if (tx == 0) {
-        float new_max = max(running_max, block_max);
-        running_denominator = running_denominator * __expf(running_max - new_max) +
-                              block_sum * __expf(block_max - new_max);
-        running_max = new_max;
-
-        temp_storage.shared_state.max_val = running_max;
-        temp_storage.shared_state.denominator = running_denominator;
+        threadlocal_sum += __expf(logits_vec[j] - block_max);
       }
-      __syncthreads();
-      running_max = temp_storage.shared_state.max_val;
-      running_denominator = temp_storage.shared_state.denominator;
+      float new_max = max(running_max, block_max);
+      threadlocal_running_denominator =
+          threadlocal_running_denominator * __expf(running_max - new_max) +
+          threadlocal_sum * __expf(block_max - new_max);
+      running_max = new_max;
     }
   }
 
+  running_denominator = cub::BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                            .Sum(threadlocal_running_denominator);
+  if (tx == 0) {
+    temp_storage.shared_state.denominator = running_denominator;
+  }
+  __syncthreads();
+  running_denominator = temp_storage.shared_state.denominator;
+
   const float final_max = running_max;
   const float inv_denominator = 1.0f / running_denominator;
-  __syncthreads();
-
   // Pass 2: Normalize in place
   vec_t<float, VEC_SIZE> prob_vec;
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -458,6 +452,7 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part
   vec_t<float, VEC_SIZE> logits_vec;
   float running_max = -cuda::std::numeric_limits<float>::infinity();
   float running_denominator = 0.0f;
+  float threadlocal_running_denominator = 0.0f;
 
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.wait;");
@@ -489,31 +484,27 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part
 
     // if block_max is -inf, then this block contains all -inf values, so we can skip updating
    if (!isinf(block_max)) {
-      float thread_sum = 0.0f;
+      float threadlocal_sum = 0.0f;
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        thread_sum += __expf(logits_vec[j] - block_max);
-      }
-
-      float block_sum =
-          cub::BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce).Sum(thread_sum);
-      __syncthreads();
-
-      if (tx == 0) {
-        float new_max = max(running_max, block_max);
-        running_denominator = running_denominator * __expf(running_max - new_max) +
-                              block_sum * __expf(block_max - new_max);
-        running_max = new_max;
-
-        temp_storage.shared_state.max_val = running_max;
-        temp_storage.shared_state.denominator = running_denominator;
+        threadlocal_sum += __expf(logits_vec[j] - block_max);
       }
-      __syncthreads();
-      running_max = temp_storage.shared_state.max_val;
-      running_denominator = temp_storage.shared_state.denominator;
+      float new_max = max(running_max, block_max);
+      threadlocal_running_denominator =
+          threadlocal_running_denominator * __expf(running_max - new_max) +
+          threadlocal_sum * __expf(block_max - new_max);
+      running_max = new_max;
     }
   }
 
+  running_denominator = cub::BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                            .Sum(threadlocal_running_denominator);
+  if (tx == 0) {
+    temp_storage.shared_state.denominator = running_denominator;
+  }
+  __syncthreads();
+  running_denominator = temp_storage.shared_state.denominator;
+
   if (tx == 0) {
     partial_results[bx * num_slices + by] = {running_max, running_denominator};
   }
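Reviewer note on the two softmax hunks above: each thread now carries its own running denominator, rescaled against the block-wide running maximum that is still broadcast per chunk, and the per-chunk block reduction plus barrier is replaced by a single reduction after the loop. The rescaling this relies on is the usual online-softmax merge rule; a NumPy sketch of that identity (illustrative only, not part of the patch):

    import numpy as np

    def merge(state, chunk):
        # combine two (max, denominator) softmax states; associative and order-independent
        m_a, d_a = state
        m_b, d_b = chunk
        m = max(m_a, m_b)
        return m, d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)

    x = np.random.randn(8192)
    state = (-np.inf, 0.0)
    for chunk in np.split(x, 8):
        state = merge(state, (chunk.max(), np.exp(chunk - chunk.max()).sum()))
    m, denom = state
    ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
    assert np.allclose(np.exp(x - m) / denom, ref)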
@@ -887,6 +878,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       double pivot_1 = (pivot_0 + high) / 2;
 
       ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+      ValueCount<float> threadlocal_gt_pivot_0{0, 0}, threadlocal_gt_pivot_1{0, 0};
 #pragma unroll 2
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         probs_vec.fill(0);
@@ -903,26 +895,27 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
           probs_gt_pivot_1[j] = {
              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          threadlocal_gt_pivot_0 += probs_gt_pivot_0[j];
+          threadlocal_gt_pivot_1 += probs_gt_pivot_1[j];
         }
+      }
+      aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                  temp_storage.block_prim.reduce_value_count)
+                                  .Sum(threadlocal_gt_pivot_0);
+      if (tx == 0) {
+        temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
 
-        aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                    temp_storage.block_prim.reduce_value_count)
-                                    .Sum(probs_gt_pivot_0);
-        if (tx == 0) {
-          temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
-
-        aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                    temp_storage.block_prim.reduce_value_count)
-                                    .Sum(probs_gt_pivot_1);
-        if (tx == 0) {
-          temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
+      aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                  temp_storage.block_prim.reduce_value_count)
+                                  .Sum(threadlocal_gt_pivot_1);
+      if (tx == 0) {
+        temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
       }
+      __syncthreads();
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
 
       if (aggregate_gt_pivot_0.count < k) {  // case 1: pivot_0 accepted
         break;
       }
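The same restructuring repeats in every kernel below: instead of one block-wide reduction (plus barrier) per chunk of `BLOCK_THREADS * VEC_SIZE` elements inside the `for i` loop, each thread folds its own elements into a thread-local partial and a single block reduction runs per pivot evaluation after the loop. Because the reduction is a plain sum, both orders give the same aggregate; a small pure-Python illustration with made-up sizes:

    import numpy as np

    d, block_threads, vec_size = 128512, 1024, 4
    probs = np.random.rand(d)
    chunk = block_threads * vec_size

    # old shape: one block-wide reduction per chunk, accumulated across chunks
    per_chunk = sum(probs[i:i + chunk].sum() for i in range(0, d, chunk))

    # new shape: every "thread" accumulates locally, one reduction at the end
    thread_partials = np.zeros(block_threads)
    for i in range(0, d, chunk):
        tile = probs[i:i + chunk]
        for tx in range(block_threads):
            thread_partials[tx] += tile[tx * vec_size:(tx + 1) * vec_size].sum()
    single_reduce = thread_partials.sum()

    assert np.isclose(per_chunk, single_reduce)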
@@ -1000,6 +993,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       double pivot_1 = (pivot_0 + high) / 2;
 
       float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
+      float threadlocal_aggregate_gt_pivot_0 = 0;
+      float threadlocal_aggregate_gt_pivot_1 = 0;
 #pragma unroll 2
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         probs_vec.fill(0);
@@ -1012,24 +1007,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
         for (uint32_t j = 0; j < VEC_SIZE; ++j) {
           probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0;
           probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
+          threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j];
+          threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j];
         }
+      }
+      aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                                  .Sum(threadlocal_aggregate_gt_pivot_0);
+      if (tx == 0) {
+        temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
 
-        aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                                    .template Sum<VEC_SIZE>(probs_gt_pivot_0);
-        if (tx == 0) {
-          temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
-
-        aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                                    .template Sum<VEC_SIZE>(probs_gt_pivot_1);
-        if (tx == 0) {
-          temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_1 = temp_storage.block_aggregate.value;
+      aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                                  .Sum(threadlocal_aggregate_gt_pivot_1);
+      if (tx == 0) {
+        temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
       }
+      __syncthreads();
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.value;
+
       if (aggregate_gt_pivot_0 < top_p) {  // case 1: pivot_0 accepted
         break;
@@ -1077,6 +1074,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   vec_t<float, VEC_SIZE> probs_vec;
 
   float aggregate_gt_pivot = 0;
+  float threadlocal_aggregate_gt_pivot = 0;
 #pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
@@ -1088,15 +1086,16 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
       probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
+      threadlocal_aggregate_gt_pivot += probs_gt_pivot[j];
     }
+  }
 
-    aggregate_gt_pivot += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                              .Sum(probs_gt_pivot);
-    if (tx == 0) {
-      temp_storage.block_aggregate.value = aggregate_gt_pivot;
-    }
-    __syncthreads();
+  aggregate_gt_pivot += BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                            .Sum(threadlocal_aggregate_gt_pivot);
+  if (tx == 0) {
+    temp_storage.block_aggregate.value = aggregate_gt_pivot;
   }
+  __syncthreads();
 
   float aggregate = 0;
   float q = temp_storage.block_aggregate.value;
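For orientation in the surrounding pivot-search loops: each pass over the vocabulary evaluates two candidate pivots at once, accumulating the probability mass (and, where relevant, the element count) above each; the enclosing do/while then narrows the bracket from those aggregates, of which only the first acceptance test is visible in this excerpt. A rough, hypothetical NumPy reference for the per-round work (scalars instead of the kernels' value/count pairs):

    import numpy as np

    def aggregates_for_two_pivots(probs: np.ndarray, pivot_0: float, pivot_1: float):
        """One pass, two candidate pivots: mass and count of entries above each."""
        gt0 = probs > pivot_0
        gt1 = probs > pivot_1
        return (probs[gt0].sum(), int(gt0.sum())), (probs[gt1].sum(), int(gt1.sum()))

    probs = np.random.dirichlet(np.ones(128512))
    (agg0_value, agg0_count), (agg1_value, agg1_count) = aggregates_for_two_pivots(probs, 1e-5, 2e-5)
    # e.g. the visible acceptance test in TopKSamplingFromProbKernel reads: agg0_count < k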
@@ -1187,6 +1186,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
       double pivot_1 = (pivot_0 + high) / 2;
 
       ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+      ValueCount<float> threadlocal_aggregate_gt_pivot_0{0, 0};
+      ValueCount<float> threadlocal_aggregate_gt_pivot_1{0, 0};
 #pragma unroll 2
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         probs_vec.fill(0);
@@ -1203,26 +1204,27 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
           probs_gt_pivot_1[j] = {
              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j];
+          threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j];
         }
+      }
+      aggregate_gt_pivot_0 +=
+          BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
+              .Sum(threadlocal_aggregate_gt_pivot_0);
+      if (tx == 0) {
+        temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
 
-        aggregate_gt_pivot_0 +=
-            BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
-                .Sum(probs_gt_pivot_0);
-        if (tx == 0) {
-          temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
-
-        aggregate_gt_pivot_1 +=
-            BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
-                .Sum(probs_gt_pivot_1);
-        if (tx == 0) {
-          temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
-        }
-        __syncthreads();
-        aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
+      aggregate_gt_pivot_1 +=
+          BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
+              .Sum(threadlocal_aggregate_gt_pivot_1);
+      if (tx == 0) {
+        temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
       }
+      __syncthreads();
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
 
       if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) {  // case 1: pivot_0 accepted
         break;
       }
@@ -1663,6 +1665,8 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
     float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
     min_gt_low = high;
     max_le_high = low;
+    float threadlocal_aggregate_gt_pivot_0 = 0;
+    float threadlocal_aggregate_gt_pivot_1 = 0;
 #pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
@@ -1682,18 +1686,19 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
         if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
           max_le_high = max(max_le_high, probs_vec[j]);
         }
+        threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0[j];
+        threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1[j];
       }
-
-      aggregate_gt_pivot_0 +=
-          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-              .template Sum<VEC_SIZE>(probs_gt_pivot_0);
-      __syncthreads();
-
-      aggregate_gt_pivot_1 +=
-          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-              .template Sum<VEC_SIZE>(probs_gt_pivot_1);
-      __syncthreads();
     }
+    aggregate_gt_pivot_0 =
+        BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+            .Sum(threadlocal_aggregate_gt_pivot_0);
+    __syncthreads();
+    aggregate_gt_pivot_1 =
+        BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+            .Sum(threadlocal_aggregate_gt_pivot_1);
+    __syncthreads();
+
     min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                      .Reduce(min_gt_low, MinReduceOp<float>{});
     __syncthreads();
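As a reminder of what TopPRenormProbKernel (and the top_p_renorm_probs benchmark added above) computes: keep the smallest set of highest-probability entries whose cumulative mass reaches p, renormalize them, and zero the rest. A simplified eager reference in PyTorch (tie-breaking and the exact threshold convention of the CUDA pivot search may differ slightly):

    import torch

    def top_p_renorm_reference(probs: torch.Tensor, p: float) -> torch.Tensor:
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        # a token is kept while the mass accumulated before it is still below p
        keep_sorted = cum - sorted_probs < p
        keep = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)
        kept = probs * keep
        return kept / kept.sum(dim=-1, keepdim=True)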
@@ -1783,6 +1788,8 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
     int aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
     min_gt_low = high;
     max_le_high = low;
+    int threadlocal_aggregate_gt_pivot_0 = 0;
+    int threadlocal_aggregate_gt_pivot_1 = 0;
 #pragma unroll 2
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       logits_vec.fill(0);
@@ -1803,18 +1810,20 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
         if (logits_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
           max_le_high = max(max_le_high, logits_vec[j]);
         }
+        threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0_count[j];
+        threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_count[j];
       }
+    }
+    aggregate_gt_pivot_0 +=
+        BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
+            .Sum(threadlocal_aggregate_gt_pivot_0);
+    __syncthreads();
 
-      aggregate_gt_pivot_0 +=
-          BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
-              .Sum(probs_gt_pivot_0_count);
-      __syncthreads();
+    aggregate_gt_pivot_1 +=
+        BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
+            .Sum(threadlocal_aggregate_gt_pivot_1);
+    __syncthreads();
 
-      aggregate_gt_pivot_1 +=
-          BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce_int)
-              .Sum(probs_gt_pivot_1_count);
-      __syncthreads();
-    }
     min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                      .Reduce(min_gt_low, MinReduceOp<float>{});
     __syncthreads();
@@ -1901,6 +1910,8 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
     ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
     min_gt_low = high;
     max_le_high = low;
+    ValueCount<float> threadlocal_aggregate_gt_pivot_0{0, 0},
+        threadlocal_aggregate_gt_pivot_1{0, 0};
 #pragma unroll 1
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(0);
@@ -1923,18 +1934,20 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
         if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
           max_le_high = max(max_le_high, probs_vec[j]);
         }
+        threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0_pair[j];
+        threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_pair[j];
       }
+    }
+    aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_value_count)
+                                .Sum(threadlocal_aggregate_gt_pivot_0);
+    __syncthreads();
 
-      aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                  temp_storage.block_prim.reduce_value_count)
-                                  .template Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
-      __syncthreads();
+    aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_value_count)
+                                .Sum(threadlocal_aggregate_gt_pivot_1);
+    __syncthreads();
 
-      aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                  temp_storage.block_prim.reduce_value_count)
-                                  .template Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
-      __syncthreads();
-    }
     min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                      .Reduce(min_gt_low, MinReduceOp<float>{});