diff --git a/vllm/model_executor/kernels/mhc/tilelang_kernels.py b/vllm/model_executor/kernels/mhc/tilelang_kernels.py index 5cc91a470a31..9fa13041b3f3 100644 --- a/vllm/model_executor/kernels/mhc/tilelang_kernels.py +++ b/vllm/model_executor/kernels/mhc/tilelang_kernels.py @@ -309,7 +309,7 @@ def mhc_pre_big_fuse_with_norm_tilelang( sumsq_per_pos = T.alloc_fragment(hidden_block, T.float32) T.clear(sumsq_per_pos) - for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=3): + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): xs = T.alloc_shared((hc_mult, hidden_block), T.bfloat16) xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) T.copy(residual[i, 0, i0_h * hidden_block], xs)