155 changes: 82 additions & 73 deletions vllm/model_executor/models/longcat_flash.py
@@ -69,6 +69,7 @@

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
@@ -485,6 +486,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config

self.vocab_size = config.vocab_size

@@ -551,77 +553,6 @@ def forward(
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Flash model for causal language modeling."""

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
quant_config = vllm_config.quant_config

self.config = config
config.intermediate_size = (
config.ffn_hidden_size
if hasattr(config, "ffn_hidden_size")
else config.intermediate_size
)

self.quant_config = quant_config

self.model = FlashModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)

if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -730,9 +661,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params.add(name)
for layer_id in range(self.config.num_hidden_layers):
for i in range(2):
if isinstance(self.model.layers[layer_id], PPMissingLayer):
if isinstance(self.layers[layer_id], PPMissingLayer):
continue
self_attn = self.model.layers[layer_id].self_attn[i]
self_attn = self.layers[layer_id].self_attn[i]
if hasattr(
self.quant_config, "weight_block_size"
) and self_attn.kv_b_proj.weight.dtype in (
@@ -765,3 +696,81 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self.config.hidden_size / self.config.kv_lora_rank
) ** 0.5
return loaded_params


class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Flash model for causal language modeling."""

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = FlashConfig(**vllm_config.model_config.hf_config.__dict__)
quant_config = vllm_config.quant_config

self.config = config
config.intermediate_size = (
config.ffn_hidden_size
if hasattr(config, "ffn_hidden_size")
else config.intermediate_size
)

self.quant_config = quant_config

self.model = FlashModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)

if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

Comment on lines +771 to +773
Contributor

Severity: high

This get_expert_mapping method appears to be a remnant from the refactoring and is now dead code. The weight loading is now handled by AutoWeightsLoader, which delegates to FlashModel.load_weights. FlashModel.load_weights in turn calls FlashModel.get_expert_mapping. This version of get_expert_mapping on LongcatFlashForCausalLM is no longer used.

Leaving this here could cause confusion or bugs in the future if other code accidentally calls it. It would be best to remove it to improve maintainability.
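
To make the delegation chain described above concrete, here is a minimal, self-contained sketch (hypothetical `Toy*` names, not the real vllm classes): the outer `*ForCausalLM` wrapper hands loading to a generic loader, the loader dispatches to the inner model's `load_weights`, and only the inner model consults `get_expert_mapping`.

```python
from collections.abc import Iterable

import torch
import torch.nn as nn


class ToyAutoWeightsLoader:
    """Hypothetical stand-in for AutoWeightsLoader: dispatch loading to any
    direct child module that defines its own load_weights."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        weights = list(weights)
        loaded: set[str] = set()
        for child_name, child in self.module.named_children():
            if not hasattr(child, "load_weights"):
                continue
            prefix = child_name + "."
            child_weights = [
                (name[len(prefix):], w) for name, w in weights if name.startswith(prefix)
            ]
            loaded |= {prefix + n for n in child.load_weights(child_weights)}
        return loaded


class ToyInnerModel(nn.Module):
    """Stands in for FlashModel: the single owner of expert mapping and loading."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # (param_name, weight_name, expert_id, shard_id)
        return [("experts.w13_weight", "experts.0.gate_proj.weight", 0, "w1")]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        _ = self.get_expert_mapping()  # only the inner model consults the mapping
        params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded


class ToyForCausalLM(nn.Module):
    """Stands in for LongcatFlashForCausalLM: keeps no loading logic of its own."""

    def __init__(self):
        super().__init__()
        self.model = ToyInnerModel()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return ToyAutoWeightsLoader(self).load_weights(weights)


lm = ToyForCausalLM()
print(lm.load_weights([("model.proj.weight", torch.zeros(4, 4))]))
# -> {'model.proj.weight'}
```

In the real PR, AutoWeightsLoader plays the loader role and FlashModel owns both load_weights and get_expert_mapping, so there is a single source of truth for the expert parameter mapping.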

Contributor Author

I changed the outer get_expert_mapping implementation to forward to FlashModel.get_expert_mapping, matching the pattern used by other MoE ForCausalLM wrappers such as Qwen2MoeForCausalLM to avoid duplicate loading logic.

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)