diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py
index d81df6f33737..3dd1118aa8a4 100644
--- a/vllm/model_executor/models/longcat_flash.py
+++ b/vllm/model_executor/models/longcat_flash.py
@@ -69,6 +69,7 @@

 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (
+    AutoWeightsLoader,
     PPMissingLayer,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -485,6 +486,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.quant_config = quant_config

         self.vocab_size = config.vocab_size

@@ -551,77 +553,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

-
-class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    """Flash model for causal language modeling."""
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
-        quant_config = vllm_config.quant_config
-
-        self.config = config
-        config.intermediate_size = (
-            config.ffn_hidden_size
-            if hasattr(config, "ffn_hidden_size")
-            else config.intermediate_size
-        )
-
-        self.quant_config = quant_config
-
-        self.model = FlashModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
-        )
-
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors
-        )
-
-    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.embed_input_ids(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor | None,
-        positions: torch.Tensor,
-        intermediate_tensors: IntermediateTensors | None = None,
-        inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors:
-        hidden_states = self.model(
-            input_ids, positions, intermediate_tensors, inputs_embeds
-        )
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor | None:
-        logits = self.logits_processor(self.lm_head, hidden_states)
-        return logits
-
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -730,9 +661,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                loaded_params.add(name)
        for layer_id in range(self.config.num_hidden_layers):
            for i in range(2):
-                if isinstance(self.model.layers[layer_id], PPMissingLayer):
+                if isinstance(self.layers[layer_id], PPMissingLayer):
                    continue
-                self_attn = self.model.layers[layer_id].self_attn[i]
+                self_attn = self.layers[layer_id].self_attn[i]
                if hasattr(
                    self.quant_config, "weight_block_size"
                ) and self_attn.kv_b_proj.weight.dtype in (
@@ -765,3 +696,81 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                    self.config.hidden_size / self.config.kv_lora_rank
                ) ** 0.5
        return loaded_params
+
+
+class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    """Flash model for causal language modeling."""
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
+        quant_config = vllm_config.quant_config
+
+        self.config = config
+        config.intermediate_size = (
+            config.ffn_hidden_size
+            if hasattr(config, "ffn_hidden_size")
+            else config.intermediate_size
+        )
+
+        self.quant_config = quant_config
+
+        self.model = FlashModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+        )
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
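The new `LongcatFlashForCausalLM.load_weights` hands the checkpoint stream to `AutoWeightsLoader(self)`, which routes `model.*`-prefixed tensors to `FlashModel.load_weights` (where the expert mapping and the `kv_b_proj` MLA scaling now live) and loads top-level parameters such as `lm_head` directly. The sketch below only illustrates that routing pattern under assumed behaviour; it is not vLLM's actual `AutoWeightsLoader`, and the name `ToyAutoWeightsLoader` is invented for the example.

```python
from collections.abc import Iterable

import torch
import torch.nn as nn


class ToyAutoWeightsLoader:
    """Minimal stand-in for the delegation pattern: route each (name, tensor)
    pair either to a child module that defines its own load_weights() -- here
    that would be FlashModel under the "model." prefix -- or copy it straight
    into one of the wrapper's own parameters (e.g. lm_head.weight)."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        params = dict(self.module.named_parameters())
        per_child: dict[str, list[tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            child, _, rest = name.partition(".")
            sub = getattr(self.module, child, None)
            if rest and hasattr(sub, "load_weights"):
                # Strip the child prefix and let the submodule handle the rest.
                per_child.setdefault(child, []).append((rest, tensor))
            elif name in params:
                # Plain parameter owned by the wrapper itself.
                params[name].data.copy_(tensor)
                loaded.add(name)
        for child, child_weights in per_child.items():
            for sub_name in getattr(self.module, child).load_weights(child_weights):
                loaded.add(f"{child}.{sub_name}")
        return loaded
```

In the patched file the wrapper's call sequence is just `AutoWeightsLoader(self).load_weights(weights)`, so all model-specific handling stays in `FlashModel.load_weights` rather than being duplicated in the causal-LM wrapper.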