diff --git a/python/sglang/jit_kernel/diffusion/triton/scale_shift.py b/python/sglang/jit_kernel/diffusion/triton/scale_shift.py index 9768b06c34db..1c9ca007d1ec 100644 --- a/python/sglang/jit_kernel/diffusion/triton/scale_shift.py +++ b/python/sglang/jit_kernel/diffusion/triton/scale_shift.py @@ -79,13 +79,20 @@ def _fused_layernorm_scale_shift_gate_select01_kernel( shift1_ptrs = shift1_ptr + batch_idx * stride_sh1_b + cols * stride_sh1_c gate1_ptrs = gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c - scale_ptrs = tl.where(idx, scale1_ptrs, scale0_ptrs) - shift_ptrs = tl.where(idx, shift1_ptrs, shift0_ptrs) - gate_ptrs = tl.where(idx, gate1_ptrs, gate0_ptrs) - - scale = tl.load(scale_ptrs, mask=mask, other=0.0).to(tl.float32) - shift = tl.load(shift_ptrs, mask=mask, other=0.0).to(tl.float32) - gate = tl.load(gate_ptrs, mask=mask, other=0.0) + # Branch on scalar idx instead of using tl.where on pointers. + # tl.where on pointers triggers an assertion in AMD Triton's + # CanonicalizePointers pass (ConvertArithSelectOp) on gfx950. + # This keeps it at 3 loads (not 6), avoids the pointer-level + # tl.where entirely, and since idx is uniform across all threads + # the branch has no divergence cost. + if idx: + scale = tl.load(scale1_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift1_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate1_ptrs, mask=mask, other=0.0) + else: + scale = tl.load(scale0_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift0_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate0_ptrs, mask=mask, other=0.0) y = x_hat * (1.0 + scale) + shift tl.store(out_row_ptr + cols, y, mask=mask) @@ -180,13 +187,20 @@ def _fused_residual_layernorm_scale_shift_gate_select01_kernel( shift1_ptrs = shift1_ptr + batch_idx * stride_sh1_b + cols * stride_sh1_c gate1_ptrs = gate1_ptr + batch_idx * stride_g1_b + cols * stride_g1_c - scale_ptrs = tl.where(idx, scale1_ptrs, scale0_ptrs) - shift_ptrs = tl.where(idx, shift1_ptrs, shift0_ptrs) - gate_ptrs = tl.where(idx, gate1_ptrs, gate0_ptrs) - - scale = tl.load(scale_ptrs, mask=mask, other=0.0).to(tl.float32) - shift = tl.load(shift_ptrs, mask=mask, other=0.0).to(tl.float32) - gate = tl.load(gate_ptrs, mask=mask, other=0.0) + # Branch on scalar idx instead of using tl.where on pointers. + # tl.where on pointers triggers an assertion in AMD Triton's + # CanonicalizePointers pass (ConvertArithSelectOp) on gfx950. + # This keeps it at 3 loads (not 6), avoids the pointer-level + # tl.where entirely, and since idx is uniform across all threads + # the branch has no divergence cost. + if idx: + scale = tl.load(scale1_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift1_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate1_ptrs, mask=mask, other=0.0) + else: + scale = tl.load(scale0_ptrs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift0_ptrs, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate0_ptrs, mask=mask, other=0.0) y = x_hat * (1.0 + scale) + shift tl.store(out_row_ptr + cols, y, mask=mask)