From 1687b2beb2cc0bbc6f1a96de8b1367cd6af9d745 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 3 Mar 2026 00:13:40 +0000 Subject: [PATCH] [SGLang-Diffusion] Fix fake impl eps default for torch.compile tracing Add default value `eps=1e-5` to `register_fake` implementations of `fused_norm_scale_shift` and `fused_scale_residual_norm_scale_shift` custom ops, matching the default in the actual custom_op signatures. Made-with: Cursor --- .../diffusion/cutedsl/scale_residual_norm_scale_shift.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py b/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py index ba8782d3d8e4..734b9bbf78ad 100644 --- a/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py +++ b/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py @@ -341,7 +341,7 @@ def fused_norm_scale_shift( @fused_norm_scale_shift.register_fake -def _fused_norm_scale_shift_fake(x, weight, bias, scale, shift, norm_type, eps): +def _fused_norm_scale_shift_fake(x, weight, bias, scale, shift, norm_type, eps=1e-5): y = x.new_empty(x.shape) return y @@ -424,7 +424,7 @@ def fused_scale_residual_norm_scale_shift( @fused_scale_residual_norm_scale_shift.register_fake def _fused_scale_residual_norm_scale_shift_fake( - residual, x, gate, weight, bias, scale, shift, norm_type, eps + residual, x, gate, weight, bias, scale, shift, norm_type, eps=1e-5 ): y = x.new_empty(x.shape) residual_out = x.new_empty(x.shape)