diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index dfb64fcd0869..581032ef7d24 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin): naming of attributes. - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor parallel plan applied to the sub-module when `model.tensor_parallel` is called. + - **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a + pipeline parallel plan that enables users to place the child-module on the appropriate device. Common attributes (present in all subclasses): @@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None + base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): @@ -860,6 +863,9 @@ def to_diff_dict(self) -> Dict[str, Any]: # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in serializable_config_dict: del serializable_config_dict["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_pp_plan"] return serializable_config_dict @@ -882,6 +888,9 @@ def to_dict(self) -> Dict[str, Any]: # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in output: del output["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in output: + del output["base_model_pp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e292b1061a28..13c8719b3603 100755 --- 
a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,6 +28,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from functools import partial, wraps from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name +class PipelineParallel(Enum): + inputs = 0 + outputs = 1 + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + # A pipeline parallel plan specifying the layers which may not be present + # on all ranks when PP is enabled. For top-level models, this attribute is + # currently defined in respective model code. For base models, this + # attribute comes from `config.base_model_pp_plan` during `post_init`. 
+ # + # The variable names for the inputs and outputs of the specified layers can + # be indexed using the `PipelineParallel` enum as follows: + # - `_pp_plan["layers"][PipelineParallel.inputs]` + # - `_pp_plan["layers"][PipelineParallel.outputs]` + _pp_plan = None + # This flag signal that the model can be used as an efficient backend in TGI and vLLM # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan @@ -1374,6 +1391,9 @@ def post_init(self): # If current model is a base model, attach `base_model_tp_plan` from config if self.base_model is self: self._tp_plan = self.config.base_model_tp_plan + # If current model is a base model, attach `base_model_pp_plan` from config + if self.base_model is self: + self._pp_plan = self.config.base_model_pp_plan def dequantize(self): """ @@ -5196,6 +5216,15 @@ def tplize(mod: torch.nn.Module) -> None: # function to every submodule. 
self.apply(tplize) + @property + def supports_pp_plan(self): + if self._pp_plan is not None: + return True + # Check if base model has PP plan + if getattr(self.base_model, "_pp_plan", None) is not None: + return True + return False + @property def loss_function(self): if hasattr(self, "_loss_function"): diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index ff34d59f5dfe..fed90c86b4a7 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } base_config_key = "text_config" def __init__( diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index ee692c9616f9..dacc92b7951f 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 41ba1c5b26e0..6fdce41e5a68 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1446,6 +1446,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + 
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere/configuration_cohere.py b/src/transformers/models/cohere/configuration_cohere.py index dc9ca2cb4dd0..eeeb23642802 100644 --- a/src/transformers/models/cohere/configuration_cohere.py +++ b/src/transformers/models/cohere/configuration_cohere.py @@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 90b4e6dc63c1..5101a0f9e083 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py index 88d3265eadfe..c792ab3f8278 100644 --- a/src/transformers/models/cohere2/configuration_cohere2.py +++ b/src/transformers/models/cohere2/configuration_cohere2.py @@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index e900413740cc..df0cb24d79c4 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: Cohere2Config): super().__init__(config) diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index f24e1378ecc7..979b5abc2600 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 38d9d3ce001d..301668d21a82 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 38b285be7373..ef086ab12ea8 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Emu3TextConfig def __init__(self, config): diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index b8470e92fb32..2aeb20058058 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 047516a4b162..59b7dc3dc347 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 540ce2b87c15..dc8ced15f962 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 3dfbd6a107e9..c9e66f8beace 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e21d0b656a86..d55fafe05677 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,6 +790,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + 
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 351f083f813d..76123af3ec52 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/glm/configuration_glm.py b/src/transformers/models/glm/configuration_glm.py index 553b71cf234d..f9a3ab53a931 100644 --- a/src/transformers/models/glm/configuration_glm.py +++ b/src/transformers/models/glm/configuration_glm.py @@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0e73af6b1554..54c138212e86 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index e570662c1021..cea854eabb94 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig): "layers.*.mlp.dense_h_to_4h": "colwise", "layers.*.mlp.dense_4h_to_h": "rowwise", } + base_model_pp_plan = { + "embed_in": (["input_ids"], ["inputs_embeds"]), + "emb_dropout": (["inputs_embeds"], ["hidden_states"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layer_norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index d83ee58af5ee..efb298243177 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 295882a9eedb..3a7cc49542ef 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite/configuration_granite.py b/src/transformers/models/granite/configuration_granite.py index 404d60ca32a3..fc651a94e1bd 100644 --- a/src/transformers/models/granite/configuration_granite.py +++ b/src/transformers/models/granite/configuration_granite.py @@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 52cdc96e6435..85c8e97c77ad 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/helium/configuration_helium.py b/src/transformers/models/helium/configuration_helium.py index 73c2638a88c2..7b27c6e54b69 100644 --- a/src/transformers/models/helium/configuration_helium.py +++ b/src/transformers/models/helium/configuration_helium.py @@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 3c17e18e4c11..86635f2d7200 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: HeliumConfig): super().__init__(config) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 646c06bdc4ba..066534f109fa 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 566fa57413f9..a06084e82567 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index c4b874f27017..3a237bc7343b 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a6d9a54efc09..92e555b3d768 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index c3f7ec8e4cc1..d9b02e10fc4c 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig): "layers.*.block_sparse_moe.experts.*.w2": "rowwise", "layers.*.block_sparse_moe.experts.*.w3": "colwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 251187677fd7..0835e33722e9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -967,6 +967,7 @@ def load_balancing_loss_func( class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index d80910e8456b..ded0bf4f017c 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], 
["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index ef3e10582f59..37d15475be89 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index ce434f541608..222c8e179154 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 561b7fdf089e..40c912ef1a63 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 04c7f0f486bd..bc5a9b89d501 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 2733d77ff674..06e5cbec2ead 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig): "layers.*.mlp.fc1": "colwise", "layers.*.mlp.fc2": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "embed_dropout": (["inputs_embeds"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layernorm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a62249460eef..8ab41d2a0cbe 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 6fe6e1cdfca8..a6b7ec9baf56 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ca6992d377bb..2595278048c8 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2/configuration_qwen2.py b/src/transformers/models/qwen2/configuration_qwen2.py index 16ce924b9f16..16979865e4fe 100644 --- a/src/transformers/models/qwen2/configuration_qwen2.py +++ b/src/transformers/models/qwen2/configuration_qwen2.py @@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8310379d83da..91eac84ffcb2 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index a1cf06d94e87..b2bf37ba0c14 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index ac6e8ae17acb..a52b4204a662 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bfb4e81d3ec3..8157408f42e2 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1217,6 +1217,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class 
Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 49a0836cf96a..710738e39654 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/starcoder2/configuration_starcoder2.py b/src/transformers/models/starcoder2/configuration_starcoder2.py index 7f21d1f12d8b..b617a1cad842 100644 --- a/src/transformers/models/starcoder2/configuration_starcoder2.py +++ b/src/transformers/models/starcoder2/configuration_starcoder2.py @@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig): "layers.*.mlp.c_fc": "colwise", "layers.*.mlp.c_proj": "colwise", } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } def __init__( self, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9314f05b4964..f176d5311dd5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config)