-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[Model] Refactor JambaForCausalLM #21394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |||||||||||||
| from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||||||||||||||
| DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) | ||||||||||||||
| from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||||||||||||||
| from vllm.model_executor.models.llama import LlamaMLP as JambaMLP | ||||||||||||||
| from vllm.model_executor.models.mamba_cache import (MambaCacheManager, | ||||||||||||||
| MambaCacheParams) | ||||||||||||||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||||||||||||||
|
|
@@ -33,7 +34,7 @@ | |||||||||||||
|
|
||||||||||||||
| from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, | ||||||||||||||
| SupportsV0Only) | ||||||||||||||
| from .utils import (is_pp_missing_parameter, | ||||||||||||||
| from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, | ||||||||||||||
| make_empty_intermediate_tensors_factory, make_layers, | ||||||||||||||
| maybe_prefix) | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -87,23 +88,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |||||||||||||
| return hidden_states.view(orig_shape) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class JambaMLP(JambaMoE): | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, | ||||||||||||||
| config: JambaConfig, | ||||||||||||||
| params_dtype: Optional[torch.dtype] = None, | ||||||||||||||
| tp_size: Optional[int] = None, | ||||||||||||||
| quant_config: Optional[QuantizationConfig] = None, | ||||||||||||||
| prefix: str = ""): | ||||||||||||||
| super().__init__(config, | ||||||||||||||
| num_experts=1, | ||||||||||||||
| top_k=1, | ||||||||||||||
| params_dtype=params_dtype, | ||||||||||||||
| tp_size=tp_size, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=prefix) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class JambaMambaDecoderLayer(nn.Module): | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, | ||||||||||||||
|
|
@@ -132,10 +116,20 @@ def __init__(self, | |||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| num_experts = config.layers_num_experts[layer_idx] | ||||||||||||||
| ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP | ||||||||||||||
| self.feed_forward = ffn_layer_class(config, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward") | ||||||||||||||
| if num_experts > 1: | ||||||||||||||
| self.feed_forward = JambaMoE( | ||||||||||||||
| config, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward", | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| self.feed_forward = JambaMLP( | ||||||||||||||
| config.hidden_size, | ||||||||||||||
| config.intermediate_size, | ||||||||||||||
| config.hidden_act, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward", | ||||||||||||||
| ) | ||||||||||||||
|
Comment on lines
+119
to
+132
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah this is cleaner |
||||||||||||||
| self.input_layernorm = RMSNorm(config.hidden_size, | ||||||||||||||
| eps=config.rms_norm_eps) | ||||||||||||||
| self.pre_ff_layernorm = RMSNorm(config.hidden_size, | ||||||||||||||
|
|
@@ -216,10 +210,20 @@ def __init__(self, | |||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| num_experts = config.layers_num_experts[layer_idx] | ||||||||||||||
| ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP | ||||||||||||||
| self.feed_forward = ffn_layer_class(config, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward") | ||||||||||||||
| if num_experts > 1: | ||||||||||||||
| self.feed_forward = JambaMoE( | ||||||||||||||
| config, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward", | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| self.feed_forward = JambaMLP( | ||||||||||||||
| config.hidden_size, | ||||||||||||||
| config.intermediate_size, | ||||||||||||||
| config.hidden_act, | ||||||||||||||
| quant_config=quant_config, | ||||||||||||||
| prefix=f"{prefix}.feed_forward", | ||||||||||||||
| ) | ||||||||||||||
| self.input_layernorm = RMSNorm(config.hidden_size, | ||||||||||||||
| eps=config.rms_norm_eps) | ||||||||||||||
| self.pre_ff_layernorm = RMSNorm(config.hidden_size, | ||||||||||||||
|
|
@@ -359,15 +363,98 @@ def forward( | |||||||||||||
| hidden_states, _ = self.final_layernorm(hidden_states, residual) | ||||||||||||||
| return hidden_states | ||||||||||||||
|
|
||||||||||||||
| def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: | ||||||||||||||
| # Params for weights, fp8 weight scales, fp8 activation scales | ||||||||||||||
| # (param_name, weight_name, expert_id, shard_id) | ||||||||||||||
| return FusedMoE.make_expert_params_mapping( | ||||||||||||||
| ckpt_gate_proj_name="gate_proj", | ||||||||||||||
| ckpt_down_proj_name="down_proj", | ||||||||||||||
| ckpt_up_proj_name="up_proj", | ||||||||||||||
| num_experts=self.config.num_experts) | ||||||||||||||
|
|
||||||||||||||
| def load_weights(self, weights: Iterable[tuple[str, | ||||||||||||||
| torch.Tensor]]) -> set[str]: | ||||||||||||||
| stacked_params_mapping = [ | ||||||||||||||
| # (param_name, shard_name, shard_id) | ||||||||||||||
| ("qkv_proj", "q_proj", "q"), | ||||||||||||||
| ("qkv_proj", "k_proj", "k"), | ||||||||||||||
| ("qkv_proj", "v_proj", "v"), | ||||||||||||||
|
Comment on lines
+379
to
+381
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Style Guide References
Suggested change
Footnotes |
||||||||||||||
| (".gate_up_proj", ".gate_proj", 0), | ||||||||||||||
| (".gate_up_proj", ".up_proj", 1), | ||||||||||||||
|
Comment on lines
+382
to
+383
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we have to handle these now? (the old code didn't mention these)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These exist in the old version of the code https://github.com/vllm-project/vllm/blob/v0.5.4/vllm/model_executor/models/jamba.py#L857, this PR is just a revert |
||||||||||||||
| ] | ||||||||||||||
|
|
||||||||||||||
| params_dict = dict(self.named_parameters()) | ||||||||||||||
| loaded_params: set[str] = set() | ||||||||||||||
| expert_params_mapping = self.get_expert_mapping() | ||||||||||||||
| for name, loaded_weight in weights: | ||||||||||||||
| if "rotary_emb.inv_freq" in name: | ||||||||||||||
| continue | ||||||||||||||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||||||||||||||
| if weight_name not in name: | ||||||||||||||
| continue | ||||||||||||||
| if 'experts' in name: | ||||||||||||||
| continue | ||||||||||||||
| name = name.replace(weight_name, param_name) | ||||||||||||||
| # Skip loading extra bias for GPTQ models. | ||||||||||||||
|
|
||||||||||||||
| if name.endswith(".bias") and name not in params_dict: | ||||||||||||||
jeejeelee marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
| continue | ||||||||||||||
| # Skip layers on other devices. | ||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = param.weight_loader | ||||||||||||||
| weight_loader(param, loaded_weight, shard_id) | ||||||||||||||
| break | ||||||||||||||
| else: | ||||||||||||||
| for ( | ||||||||||||||
| param_name, | ||||||||||||||
| weight_name, | ||||||||||||||
| expert_id, | ||||||||||||||
| shard_id, | ||||||||||||||
| ) in expert_params_mapping: | ||||||||||||||
| if weight_name not in name: | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
| name = name.replace(weight_name, param_name) | ||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = param.weight_loader | ||||||||||||||
| weight_loader(param, | ||||||||||||||
| loaded_weight, | ||||||||||||||
| name, | ||||||||||||||
| shard_id=shard_id, | ||||||||||||||
| expert_id=expert_id) | ||||||||||||||
| break | ||||||||||||||
| else: | ||||||||||||||
| # Skip loading extra bias for GPTQ models. | ||||||||||||||
| if name.endswith(".bias") and name not in params_dict: | ||||||||||||||
| continue | ||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = getattr(param, "weight_loader", | ||||||||||||||
| default_weight_loader) | ||||||||||||||
| weight_loader(param, loaded_weight) | ||||||||||||||
| loaded_params.add(name) | ||||||||||||||
| return loaded_params | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, | ||||||||||||||
| IsHybrid, SupportsV0Only): | ||||||||||||||
| hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ | ||||||||||||||
| ".self_attn.": ".", | ||||||||||||||
| ".A_log": ".A" | ||||||||||||||
| }, ) | ||||||||||||||
| packed_modules_mapping = { | ||||||||||||||
| "qkv_proj": [ | ||||||||||||||
| "q_proj", | ||||||||||||||
| "k_proj", | ||||||||||||||
| "v_proj", | ||||||||||||||
| ], | ||||||||||||||
| "gate_up_proj": ["gate_proj", "up_proj"], | ||||||||||||||
| "in_proj": ["in_proj"], | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -468,96 +555,11 @@ def compute_logits( | |||||||||||||
|
|
||||||||||||||
| def load_weights(self, weights: Iterable[tuple[str, | ||||||||||||||
| torch.Tensor]]) -> set[str]: | ||||||||||||||
| stacked_params_mapping = [ | ||||||||||||||
| # (param_name, shard_name, shard_id) | ||||||||||||||
| ("qkv_proj", "q_proj", "q"), | ||||||||||||||
| ("qkv_proj", "k_proj", "k"), | ||||||||||||||
| ("qkv_proj", "v_proj", "v"), | ||||||||||||||
| ] | ||||||||||||||
|
|
||||||||||||||
| # Params for weights, fp8 weight scales, fp8 activation scales | ||||||||||||||
| # (param_name, weight_name, expert_id, shard_id) | ||||||||||||||
| expert_params_mapping = FusedMoE.make_expert_params_mapping( | ||||||||||||||
| ckpt_gate_proj_name="gate_proj", | ||||||||||||||
| ckpt_down_proj_name="down_proj", | ||||||||||||||
| ckpt_up_proj_name="up_proj", | ||||||||||||||
| num_experts=self.config.num_experts) | ||||||||||||||
|
|
||||||||||||||
| params_dict = dict(self.named_parameters()) | ||||||||||||||
| loaded_params: set[str] = set() | ||||||||||||||
| for name, loaded_weight in weights: | ||||||||||||||
| if "rotary_emb.inv_freq" in name: | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| if "A_log" in name: | ||||||||||||||
| name = name.replace("A_log", "A") | ||||||||||||||
|
|
||||||||||||||
| if ".self_attn." in name: | ||||||||||||||
| name = name.replace(".self_attn", "") | ||||||||||||||
|
|
||||||||||||||
| if "feed_forward" in name and not _is_moe_layer(name): | ||||||||||||||
| ## map MLP layers to expert with ID=0 | ||||||||||||||
| name = name.replace("feed_forward", "feed_forward.experts.0") | ||||||||||||||
|
|
||||||||||||||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||||||||||||||
| if weight_name not in name: | ||||||||||||||
| continue | ||||||||||||||
| if 'experts' in name: | ||||||||||||||
| continue | ||||||||||||||
| name = name.replace(weight_name, param_name) | ||||||||||||||
| # Skip loading extra bias for GPTQ models. | ||||||||||||||
|
|
||||||||||||||
| if name.endswith(".bias") and name not in params_dict: | ||||||||||||||
| continue | ||||||||||||||
| # Skip layers on other devices. | ||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = param.weight_loader | ||||||||||||||
| weight_loader(param, loaded_weight, shard_id) | ||||||||||||||
| break | ||||||||||||||
| else: | ||||||||||||||
| for ( | ||||||||||||||
| param_name, | ||||||||||||||
| weight_name, | ||||||||||||||
| expert_id, | ||||||||||||||
| shard_id, | ||||||||||||||
| ) in expert_params_mapping: | ||||||||||||||
| if weight_name not in name: | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
| name = name.replace(weight_name, param_name) | ||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = param.weight_loader | ||||||||||||||
| weight_loader(param, | ||||||||||||||
| loaded_weight, | ||||||||||||||
| name, | ||||||||||||||
| shard_id=shard_id, | ||||||||||||||
| expert_id=expert_id) | ||||||||||||||
| break | ||||||||||||||
| else: | ||||||||||||||
| # Skip loading extra bias for GPTQ models. | ||||||||||||||
| if name.endswith(".bias") and name not in params_dict: | ||||||||||||||
| continue | ||||||||||||||
| if is_pp_missing_parameter(name, self): | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = getattr(param, "weight_loader", | ||||||||||||||
| default_weight_loader) | ||||||||||||||
| weight_loader(param, loaded_weight) | ||||||||||||||
| loaded_params.add(name) | ||||||||||||||
| return loaded_params | ||||||||||||||
|
|
||||||||||||||
| loader = AutoWeightsLoader(self) | ||||||||||||||
| return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) | ||||||||||||||
|
|
||||||||||||||
| def _is_moe_layer(name: str): | ||||||||||||||
| return any( | ||||||||||||||
| [experts_name in name for experts_name in [ | ||||||||||||||
| "experts", | ||||||||||||||
| "router", | ||||||||||||||
| ]]) | ||||||||||||||
| def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: | ||||||||||||||
| return self.model.get_expert_mapping() | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class JambaForSequenceClassification(JambaForCausalLM): | ||||||||||||||
|
|
||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#7415 Change JambaMLP to MoE , now this PR wants to revert, do you accept? @mzusman
The main motivation is to standardize the model implementation and support BNB quantization