diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index fa7ba01bc048..f3860e572f0b 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -388,6 +388,7 @@ def __init__( layer_scatter_modes=self.mlp_layer_scatter_modes[i], input_layernorm=self.input_layernorm[i], post_attention_layernorm=self.post_attention_layernorm[i], + qkv_latent_func=self.self_attn[i].prepare_qkv_latent, ) for i in range(2) ] @@ -402,6 +403,7 @@ def __init__( layer_scatter_modes=self.moe_layer_scatter_modes, input_layernorm=self.input_layernorm[0], post_attention_layernorm=self.post_attention_layernorm[0], + qkv_latent_func=self.self_attn[0].prepare_qkv_latent, ) def forward(