diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py
index 3cd52160dfb6..633529edf16d 100644
--- a/benchmarks/kernels/benchmark_fused_collective.py
+++ b/benchmarks/kernels/benchmark_fused_collective.py
@@ -408,18 +408,18 @@ def run_benchmarks(
     rms_eps = 1e-6
     results = {}
 
-    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
     use_oneshot_options = [False] if no_oneshot else [True, False]
 
-    # Create RMSNorm and QuantFP8 layers once for native benchmarks
-    if "none" in quant_modes:  # Standard AllReduce + RMSNorm
+    # Re-create VllmFusedAllreduce per config so CustomOp binds the
+    # correct forward method (native vs custom kernel).
     for custom_op in ["-rms_norm", "+rms_norm"]:
         with set_current_vllm_config(
             VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
         ):
             try:
+                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                 suffix = (
                     "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
                 )
@@ -438,6 +438,7 @@ def run_benchmarks(
             VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
         ):
             try:
+                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                 standard_allreduce_rmsnorm_native_compiled = torch.compile(
                     vllm_fused_allreduce.allreduce_rmsnorm,
                     fullgraph=True,
@@ -482,7 +483,7 @@ def run_benchmarks(
                 "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
             )
             for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
-                suffix += (
+                op_suffix = suffix + (
                     "_custom_quant_fp8"
                     if "+" in quant_fp8_custom_op
                     else "_native_quant_fp8"
@@ -495,16 +496,17 @@ def run_benchmarks(
                     )
                 ):
                     try:
+                        vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                         time_ms = benchmark_operation(
                             vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                             input_tensor,
                             residual=residual,
                             scale_factor=scale_fp8,
                         )
-                        results[f"standard_allreduce{suffix}"] = time_ms
+                        results[f"standard_allreduce{op_suffix}"] = time_ms
                     except Exception as e:
                         logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
-                        results[f"standard_allreduce{suffix}"] = float("inf")
+                        results[f"standard_allreduce{op_suffix}"] = float("inf")
 
     # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
     with set_current_vllm_config(
@@ -515,6 +517,7 @@ def run_benchmarks(
                     )
                 ):
                     try:
+                        vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                         standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
                             vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                             fullgraph=True,
@@ -580,6 +583,7 @@ def run_benchmarks(
                     )
                 ):
                     try:
+                        vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                         time_ms = benchmark_operation(
                             vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                             input_tensor,
@@ -598,6 +602,7 @@ def run_benchmarks(
             VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
         ):
             try:
+                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                 standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
                     vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                     fullgraph=True,