From 64f1fd67ee55ab35396b26ee6907cb5205e40716 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:21:31 +0000 Subject: [PATCH 01/24] Add `base_model_pp_plan` to `PretrainedConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index dfb64fcd0869..1681f133c051 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin): naming of attributes. - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor parallel plan applied to the sub-module when `model.tensor_parallel` is called. + - **base_model_pp_plan** (`Dict[str, Dict[[str, str]]]`) -- A dict that maps child-modules of a base model to a + pipeline parallel plan that enables users to place the child-module on the appropriate device. Common attributes (present in all subclasses): @@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None + base_model_pp_plan: Optional[Dict[Dict[str, str]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): @@ -860,6 +863,9 @@ def to_diff_dict(self) -> Dict[str, Any]: # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in serializable_config_dict: del serializable_config_dict["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_pp_plan"] return serializable_config_dict @@ -882,6 +888,9 @@ def to_dict(self) -> Dict[str, Any]: # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in output: del output["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in output: + del output["base_model_pp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ From 040ba146759afc625c77f49d04982bb54873995f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:21:53 +0000 Subject: [PATCH 02/24] Add `_pp_plan` to `PreTrainedModel` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/modeling_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e292b1061a28..2c88d1c04fbe 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1312,6 +1312,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None + # A pipeline parallel plan to be applied to the model when PP is enabled. For + # top-level models, this attribute is currently defined in respective model + # code. For base models, this attribute comes from + # `config.base_model_pp_plan` during `post_init`. + _pp_plan = None + # This flag signal that the model can be used as an efficient backend in TGI and vLLM # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan @@ -1374,6 +1380,9 @@ def post_init(self): # If current model is a base model, attach `base_model_tp_plan` from config if self.base_model is self: self._tp_plan = self.config.base_model_tp_plan + # If current model is a base model, attach `base_model_pp_plan` from config + if self.base_model is self: + self._pp_plan = self.config.base_model_pp_plan def dequantize(self): """ @@ -5196,6 +5205,15 @@ def tplize(mod: torch.nn.Module) -> None: # function to every submodule. self.apply(tplize) + @property + def supports_pp_plan(self): + if self._pp_plan is not None: + return True + # Check if base model has PP plan + if getattr(self.base_model, "_pp_plan", None) is not None: + return True + return False + @property def loss_function(self): if hasattr(self, "_loss_function"): From ac56826ca98988bba3336264103081f03a2b479c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:22:06 +0000 Subject: [PATCH 03/24] Add both to Llama for testing Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/llama/configuration_llama.py | 5 +++++ src/transformers/models/llama/modeling_llama.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 646c06bdc4ba..c163af88e533 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "input_processing": {"embed_tokens": "inputs_embeds"}, + "decoder_stack": {"layers": "layer_outputs"}, + "output_processing": {"norm": "hidden_states"}, + } def __init__( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 566fa57413f9..9021f26c0901 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"output_processing": {"lm_head": "logits"}} def __init__(self, config): super().__init__(config) From fee056f4b7e09a4fd0fa7d545db66453d2deabff Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:10:05 +0100 Subject: [PATCH 04/24] Fix type error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 1681f133c051..c4fbbe13915e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -200,7 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None - base_model_pp_plan: Optional[Dict[Dict[str, str]]] = None + base_model_pp_plan: Optional[Dict[str, Any]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): From 826cd178698e2be1e849a47606c49eeec589a2f6 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:44:53 +0100 Subject: [PATCH 05/24] Update to suggested schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 2 +- src/transformers/models/llama/configuration_llama.py | 6 +++--- src/transformers/models/llama/modeling_llama.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c4fbbe13915e..f9186d1cb9ff 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -200,7 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None - base_model_pp_plan: Optional[Dict[str, Any]] = None + base_model_pp_plan: Optional[Dict[str, Dict[str, List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index c163af88e533..4ad4dda765c0 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -152,9 +152,9 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "input_processing": {"embed_tokens": "inputs_embeds"}, - "decoder_stack": {"layers": "layer_outputs"}, - "output_processing": {"norm": "hidden_states"}, + "embed_tokens": {"input_keys": ["input_ids"], "output_keys": ["inputs_embeds"]}, + "layers.*": {"input_keys": ["hidden_states", "attention_mask"], "output_keys": ["hidden_states"]}, + "norm": {"input_keys": ["hidden_states"], "output_keys": ["hidden_states"]}, } def __init__( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9021f26c0901..37023e25adbf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,8 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"output_processing": {"lm_head": "logits"}} - + _pp_plan = {"lm_head": {"input_keys": ["input"], "output_keys": ["logits"]}} def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) From cfee6801a4188c5a5cb315763b3ca729a471a72c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Feb 2025 18:35:33 +0100 Subject: [PATCH 06/24] `_pp_plan` keys are not patterns Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/llama/configuration_llama.py | 2 +- src/transformers/models/llama/modeling_llama.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 4ad4dda765c0..d95bf6d7adc8 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -153,7 +153,7 @@ class LlamaConfig(PretrainedConfig): } base_model_pp_plan = { "embed_tokens": {"input_keys": ["input_ids"], "output_keys": ["inputs_embeds"]}, - "layers.*": {"input_keys": ["hidden_states", "attention_mask"], "output_keys": ["hidden_states"]}, + "layers": {"input_keys": ["hidden_states", "attention_mask"], "output_keys": ["hidden_states"]}, "norm": {"input_keys": ["hidden_states"], "output_keys": ["hidden_states"]}, } diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 37023e25adbf..240f3b764b56 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -751,6 +751,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": {"input_keys": ["input"], "output_keys": ["logits"]}} + def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) From 7e03c0dbef8d3cfdba6948ff19c1f172117cbf58 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sat, 8 Feb 2025 11:25:21 +0100 Subject: [PATCH 07/24] Simplify schema Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 2 +- src/transformers/models/llama/configuration_llama.py | 6 +++--- src/transformers/models/llama/modeling_llama.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f9186d1cb9ff..6d71711e6198 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -200,7 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None - base_model_pp_plan: Optional[Dict[str, Dict[str, List[str]]]] = None + base_model_pp_plan: Optional[Dict[List[List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index d95bf6d7adc8..b810c9667f69 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -152,9 +152,9 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": {"input_keys": ["input_ids"], "output_keys": ["inputs_embeds"]}, - "layers": {"input_keys": ["hidden_states", "attention_mask"], "output_keys": ["hidden_states"]}, - "norm": {"input_keys": ["hidden_states"], "output_keys": ["hidden_states"]}, + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], } def __init__( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 240f3b764b56..406f7ac4df3d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,7 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": {"input_keys": ["input"], "output_keys": ["logits"]}} + _pp_plan = {"lm_head": [["input"], ["logits"]]} def __init__(self, config): super().__init__(config) From b9ee76ffb905a8dcaadbfed64a5585556b14c573 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 14:37:03 +0100 Subject: [PATCH 08/24] Fix typing error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6d71711e6198..1991c6e8af8e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -200,7 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None - base_model_pp_plan: Optional[Dict[List[List[str]]]] = None + base_model_pp_plan: Optional[Dict[str, List[List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): From f5ed151816d7a6f2d32b15959cdc3c1b7b8c2dcb Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:07:44 +0100 Subject: [PATCH 09/24] Update input name for Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 406f7ac4df3d..b783ecd46d7f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,7 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["input"], ["logits"]]} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From cb60de854bf5770928825152f441536bc7a8c1c8 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:08:15 +0100 Subject: [PATCH 10/24] Add pp plan to Aria Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/aria/configuration_aria.py | 5 +++++ src/transformers/models/aria/modeling_aria.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index ff34d59f5dfe..79e89d0600c9 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } base_config_key = "text_config" def __init__( diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index ee692c9616f9..b89e80a8c5eb 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): From 9a07903223d302ee401e43cd647d29b415dbba4a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:09:35 +0100 Subject: [PATCH 11/24] Add pp plan to Bamba Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/bamba/modeling_bamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 41ba1c5b26e0..fdffe245a908 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1446,6 +1446,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From 57a2aa57f56d8e05d8c469019199bef21675c3a3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:10:26 +0100 Subject: [PATCH 12/24] Add pp plan to Cohere 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/cohere/configuration_cohere.py | 5 +++++ src/transformers/models/cohere/modeling_cohere.py | 1 + src/transformers/models/cohere2/configuration_cohere2.py | 5 +++++ src/transformers/models/cohere2/modeling_cohere2.py | 1 + src/transformers/models/cohere2/modular_cohere2.py | 5 +++++ 5 files changed, 17 insertions(+) diff --git a/src/transformers/models/cohere/configuration_cohere.py b/src/transformers/models/cohere/configuration_cohere.py index dc9ca2cb4dd0..7964b184aeb0 100644 --- a/src/transformers/models/cohere/configuration_cohere.py +++ b/src/transformers/models/cohere/configuration_cohere.py @@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 90b4e6dc63c1..86f57a622e4b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py index 88d3265eadfe..4605213c0327 100644 --- a/src/transformers/models/cohere2/configuration_cohere2.py +++ b/src/transformers/models/cohere2/configuration_cohere2.py @@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index e900413740cc..e384a4c1750a 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config: Cohere2Config): super().__init__(config) diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index f24e1378ecc7..5bf1b624d3d1 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, From 2703a9e81a0db76848737d64637eba20be3a1a3a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:11:25 +0100 Subject: [PATCH 13/24] Add pp plan to diffllama and emu3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/diffllama/modeling_diffllama.py | 1 + src/transformers/models/emu3/modeling_emu3.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 38d9d3ce001d..64b9b12ff4a4 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 38b285be7373..c309ed2a00c4 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} config_class = Emu3TextConfig def __init__(self, config): From 19f5346140749359823ef26c4183bd8f1aed3db7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:14:29 +0100 Subject: [PATCH 14/24] Add pp plan to Gemma 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/gemma/configuration_gemma.py | 5 +++++ src/transformers/models/gemma/modeling_gemma.py | 1 + src/transformers/models/gemma/modular_gemma.py | 5 +++++ src/transformers/models/gemma2/configuration_gemma2.py | 5 +++++ src/transformers/models/gemma2/modeling_gemma2.py | 1 + src/transformers/models/gemma2/modular_gemma2.py | 5 +++++ 6 files changed, 22 insertions(+) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index b8470e92fb32..e2b9dc5d7c0e 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 047516a4b162..27922e543cf5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 540ce2b87c15..e88b3b3da193 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 3dfbd6a107e9..da86daa0f0cd 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e21d0b656a86..dbaa8b727714 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,6 +790,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 351f083f813d..a63f74037531 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, From 5de4151d3d525d249687be4e04d2947d9afee9d3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:18:44 +0100 Subject: [PATCH 15/24] Add pp plan to GLM and GPT NeoX Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/glm/configuration_glm.py | 5 +++++ src/transformers/models/glm/modeling_glm.py | 1 + src/transformers/models/gpt_neox/configuration_gpt_neox.py | 6 ++++++ src/transformers/models/gpt_neox/modeling_gpt_neox.py | 1 + src/transformers/models/gpt_neox/modular_gpt_neox.py | 1 + 5 files changed, 14 insertions(+) diff --git a/src/transformers/models/glm/configuration_glm.py b/src/transformers/models/glm/configuration_glm.py index 553b71cf234d..e3d3b200160d 100644 --- a/src/transformers/models/glm/configuration_glm.py +++ b/src/transformers/models/glm/configuration_glm.py @@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0e73af6b1554..4d2d3dc8a49e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index e570662c1021..3e064371a44c 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig): "layers.*.mlp.dense_h_to_4h": "colwise", "layers.*.mlp.dense_4h_to_h": "rowwise", } + base_model_pp_plan = { + "embed_in": [["input_ids"], ["inputs_embeds"]], + "emb_dropout": [["inputs_embeds"], ["hidden_states"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "final_layer_norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index d83ee58af5ee..913a9c021e40 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 295882a9eedb..2eb989d4bb3f 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} + _pp_plan = {"embed_out": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From 3a490d86ccb787cd3f268d9ebf820301d2e58e7d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:21:58 +0100 Subject: [PATCH 16/24] Add pp plan to Granite and Helium Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/granite/configuration_granite.py | 5 +++++ src/transformers/models/granite/modeling_granite.py | 1 + src/transformers/models/helium/configuration_helium.py | 5 +++++ src/transformers/models/helium/modeling_helium.py | 1 + 4 files changed, 12 insertions(+) diff --git a/src/transformers/models/granite/configuration_granite.py b/src/transformers/models/granite/configuration_granite.py index 404d60ca32a3..a45aa94058db 100644 --- a/src/transformers/models/granite/configuration_granite.py +++ b/src/transformers/models/granite/configuration_granite.py @@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 52cdc96e6435..08a04ea6ade5 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/helium/configuration_helium.py b/src/transformers/models/helium/configuration_helium.py index 73c2638a88c2..97f8eeb73ed7 100644 --- a/src/transformers/models/helium/configuration_helium.py +++ b/src/transformers/models/helium/configuration_helium.py @@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 3c17e18e4c11..e3180466bef8 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config: HeliumConfig): super().__init__(config) From 0c2c8df97580def1e2f87bf3e62f6669f8739d4c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:24:51 +0100 Subject: [PATCH 17/24] Add pp plan to Mistral and Mixtral Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/mistral/configuration_mistral.py | 5 +++++ src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/configuration_mixtral.py | 5 +++++ src/transformers/models/mixtral/modeling_mixtral.py | 1 + 4 files changed, 12 insertions(+) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index c4b874f27017..640079b28892 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a6d9a54efc09..62c7493d9dea 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index c3f7ec8e4cc1..c5da125101b0 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -119,6 +119,11 @@ class MixtralConfig(PretrainedConfig): "layers.*.block_sparse_moe.experts.*.w2": "rowwise", "layers.*.block_sparse_moe.experts.*.w3": "colwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 251187677fd7..a9391499a056 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -967,6 +967,7 @@ def load_balancing_loss_func( class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From c8960621a38b80a057f3856e8bcef2c8e4396eb3 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:27:17 +0100 Subject: [PATCH 18/24] Add pp plan to OLMo 1 & 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/olmo/configuration_olmo.py | 5 +++++ src/transformers/models/olmo/modeling_olmo.py | 1 + src/transformers/models/olmo2/configuration_olmo2.py | 5 +++++ src/transformers/models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/olmo2/modular_olmo2.py | 5 +++++ 5 files changed, 17 insertions(+) diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index d80910e8456b..e79a532c6763 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -115,6 +115,11 @@ class OlmoConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index ef3e10582f59..09ae76962d6f 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -726,6 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index ce434f541608..cfb01de77e57 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -98,6 +98,11 @@ class Olmo2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 561b7fdf089e..8bab410ae21c 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -727,6 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 04c7f0f486bd..0f8c83f4a892 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -109,6 +109,11 @@ class Olmo2Config(OlmoConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, From c01266db62f9a482cab96ae83c902998578613fd Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:35:21 +0100 Subject: [PATCH 19/24] Add pp plan to Phi and Phi 3 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/phi/configuration_phi.py | 6 ++++++ src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi3/configuration_phi3.py | 5 +++++ src/transformers/models/phi3/modeling_phi3.py | 1 + 4 files changed, 13 insertions(+) diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 2733d77ff674..8ca6cffb99e6 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -146,6 +146,12 @@ class PhiConfig(PretrainedConfig): "layers.*.mlp.fc1": "colwise", "layers.*.mlp.fc2": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "embed_dropout": [["inputs_embeds"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "final_layernorm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a62249460eef..32d9df2d2cea 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -724,6 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 6fe6e1cdfca8..e09399d6ec9f 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -113,6 +113,11 @@ class Phi3Config(PretrainedConfig): "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ca6992d377bb..78969ba2173a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -821,6 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From 81c11859be51e28cc387b1f9af52cdaacbe916c7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:47:03 +0100 Subject: [PATCH 20/24] Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/qwen2/configuration_qwen2.py | 5 +++++ src/transformers/models/qwen2/modeling_qwen2.py | 1 + .../models/qwen2_5_vl/configuration_qwen2_5_vl.py | 5 +++++ src/transformers/models/qwen2_moe/configuration_qwen2_moe.py | 5 +++++ src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + src/transformers/models/qwen2_vl/configuration_qwen2_vl.py | 5 +++++ 6 files changed, 22 insertions(+) diff --git a/src/transformers/models/qwen2/configuration_qwen2.py b/src/transformers/models/qwen2/configuration_qwen2.py index 16ce924b9f16..4e4106def5cc 100644 --- a/src/transformers/models/qwen2/configuration_qwen2.py +++ b/src/transformers/models/qwen2/configuration_qwen2.py @@ -139,6 +139,11 @@ class Qwen2Config(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8310379d83da..4fe14ed20c29 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -735,6 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index a1cf06d94e87..e125cd7add3c 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -184,6 +184,11 @@ class Qwen2_5_VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index ac6e8ae17acb..eed5a5e6f276 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -160,6 +160,11 @@ class Qwen2MoeConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bfb4e81d3ec3..f24441722ea0 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1217,6 +1217,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 49a0836cf96a..9e196cd491a6 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, From 95cec4b281a679796af127addfbea8f13dc0e553 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:48:09 +0100 Subject: [PATCH 21/24] Add pp plan for Starcoder 2 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../models/starcoder2/configuration_starcoder2.py | 5 +++++ src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/transformers/models/starcoder2/configuration_starcoder2.py b/src/transformers/models/starcoder2/configuration_starcoder2.py index 7f21d1f12d8b..2b6690861817 100644 --- a/src/transformers/models/starcoder2/configuration_starcoder2.py +++ b/src/transformers/models/starcoder2/configuration_starcoder2.py @@ -143,6 +143,11 @@ class Starcoder2Config(PretrainedConfig): "layers.*.mlp.c_fc": "colwise", "layers.*.mlp.c_proj": "colwise", } + base_model_pp_plan = { + "embed_tokens": [["input_ids"], ["inputs_embeds"]], + "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], + "norm": [["hidden_states"], ["hidden_states"]], + } def __init__( self, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9314f05b4964..69ecff84a226 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -747,6 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} def __init__(self, config): super().__init__(config) From 387e3a856f6f0e973868ec4d0f579e1d24b60040 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 15:56:08 +0100 Subject: [PATCH 22/24] Add enum for accessing inputs and outputs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/modeling_utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c88d1c04fbe..13c8719b3603 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,6 +28,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from functools import partial, wraps from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name +class PipelineParallel(Enum): + inputs: 0 + outputs: 1 + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1312,10 +1318,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # `config.base_model_tp_plan` during `post_init`. _tp_plan = None - # A pipeline parallel plan to be applied to the model when PP is enabled. For - # top-level models, this attribute is currently defined in respective model - # code. For base models, this attribute comes from - # `config.base_model_pp_plan` during `post_init`. + # A pipeline parallel plan specifying the layers which may not be present + # on all ranks when PP is enabled. For top-level models, this attribute is + # currently defined in respective model code. For base models, this + # attribute comes from `config.base_model_pp_plan` during `post_init`. + # + # The variable names for the inputs and outputs of the specified layers can + # be indexed using the `PipelineParallel` enum as follows: + # - `_pp_plan["layers"][PipelineParallel.inputs]` + # - `_pp_plan["layers"][PipelineParallel.outputs]` _pp_plan = None # This flag signal that the model can be used as an efficient backend in TGI and vLLM From dca584d0a05fd2d7097a4cecbbcad527bd6409bd Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:22:27 +0100 Subject: [PATCH 23/24] Update type hints to use tuples Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/configuration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 1991c6e8af8e..581032ef7d24 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -74,7 +74,7 @@ class PretrainedConfig(PushToHubMixin): naming of attributes. - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor parallel plan applied to the sub-module when `model.tensor_parallel` is called. - - **base_model_pp_plan** (`Dict[str, Dict[[str, str]]]`) -- A dict that maps child-modules of a base model to a + - **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a pipeline parallel plan that enables users to place the child-module on the appropriate device. Common attributes (present in all subclasses): @@ -200,7 +200,7 @@ class PretrainedConfig(PushToHubMixin): is_composition: bool = False attribute_map: Dict[str, str] = {} base_model_tp_plan: Optional[Dict[str, Any]] = None - base_model_pp_plan: Optional[Dict[str, List[List[str]]]] = None + base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): From 02129121cf18b07f0c958ced5947b7d72f2208b5 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:22:27 +0100 Subject: [PATCH 24/24] Change outer list to tuple Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/aria/configuration_aria.py | 6 +++--- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/bamba/modeling_bamba.py | 2 +- src/transformers/models/cohere/configuration_cohere.py | 6 +++--- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/cohere2/configuration_cohere2.py | 6 +++--- src/transformers/models/cohere2/modeling_cohere2.py | 2 +- src/transformers/models/cohere2/modular_cohere2.py | 6 +++--- src/transformers/models/diffllama/modeling_diffllama.py | 2 +- src/transformers/models/emu3/modeling_emu3.py | 2 +- src/transformers/models/gemma/configuration_gemma.py | 6 +++--- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma/modular_gemma.py | 6 +++--- src/transformers/models/gemma2/configuration_gemma2.py | 6 +++--- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma2/modular_gemma2.py | 6 +++--- src/transformers/models/glm/configuration_glm.py | 6 +++--- src/transformers/models/glm/modeling_glm.py | 2 +- .../models/gpt_neox/configuration_gpt_neox.py | 8 ++++---- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- src/transformers/models/gpt_neox/modular_gpt_neox.py | 2 +- src/transformers/models/granite/configuration_granite.py | 6 +++--- src/transformers/models/granite/modeling_granite.py | 2 +- src/transformers/models/helium/configuration_helium.py | 6 +++--- src/transformers/models/helium/modeling_helium.py | 2 +- src/transformers/models/llama/configuration_llama.py | 6 +++--- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/configuration_mistral.py | 6 +++--- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/configuration_mixtral.py | 6 +++--- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/configuration_olmo.py | 6 +++--- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/configuration_olmo2.py | 6 +++--- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/olmo2/modular_olmo2.py | 6 +++--- src/transformers/models/phi/configuration_phi.py | 8 ++++---- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/configuration_phi3.py | 6 +++--- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/configuration_qwen2.py | 6 +++--- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_5_vl/configuration_qwen2_5_vl.py | 6 +++--- .../models/qwen2_moe/configuration_qwen2_moe.py | 6 +++--- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/qwen2_vl/configuration_qwen2_vl.py | 6 +++--- .../models/starcoder2/configuration_starcoder2.py | 6 +++--- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 48 files changed, 100 insertions(+), 100 deletions(-) diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index 79e89d0600c9..fed90c86b4a7 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -145,9 +145,9 @@ class AriaTextConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } base_config_key = "text_config" diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index b89e80a8c5eb..dacc92b7951f 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1141,7 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = AriaTextConfig def __init__(self, config: AriaTextConfig): diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index fdffe245a908..6fdce41e5a68 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1446,7 +1446,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere/configuration_cohere.py b/src/transformers/models/cohere/configuration_cohere.py index 7964b184aeb0..eeeb23642802 100644 --- a/src/transformers/models/cohere/configuration_cohere.py +++ b/src/transformers/models/cohere/configuration_cohere.py @@ -149,9 +149,9 @@ class CohereConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 86f57a622e4b..5101a0f9e083 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -780,7 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py index 4605213c0327..c792ab3f8278 100644 --- a/src/transformers/models/cohere2/configuration_cohere2.py +++ b/src/transformers/models/cohere2/configuration_cohere2.py @@ -149,9 +149,9 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index e384a4c1750a..df0cb24d79c4 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -781,7 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: Cohere2Config): super().__init__(config) diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 5bf1b624d3d1..979b5abc2600 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -174,9 +174,9 @@ class Cohere2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 64b9b12ff4a4..301668d21a82 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -1019,7 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index c309ed2a00c4..ef086ab12ea8 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1598,7 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Emu3TextConfig def __init__(self, config): diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index e2b9dc5d7c0e..2aeb20058058 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -103,9 +103,9 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 27922e543cf5..59b7dc3dc347 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -752,7 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index e88b3b3da193..dc8ced15f962 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -127,9 +127,9 @@ class GemmaConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index da86daa0f0cd..c9e66f8beace 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -107,9 +107,9 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index dbaa8b727714..d55fafe05677 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,7 +790,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index a63f74037531..76123af3ec52 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -133,9 +133,9 @@ class Gemma2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/glm/configuration_glm.py b/src/transformers/models/glm/configuration_glm.py index e3d3b200160d..f9a3ab53a931 100644 --- a/src/transformers/models/glm/configuration_glm.py +++ b/src/transformers/models/glm/configuration_glm.py @@ -94,9 +94,9 @@ class GlmConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 4d2d3dc8a49e..54c138212e86 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -761,7 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 3e064371a44c..cea854eabb94 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -138,10 +138,10 @@ class GPTNeoXConfig(PretrainedConfig): "layers.*.mlp.dense_4h_to_h": "rowwise", } base_model_pp_plan = { - "embed_in": [["input_ids"], ["inputs_embeds"]], - "emb_dropout": [["inputs_embeds"], ["hidden_states"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "final_layer_norm": [["hidden_states"], ["hidden_states"]], + "embed_in": (["input_ids"], ["inputs_embeds"]), + "emb_dropout": (["inputs_embeds"], ["hidden_states"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layer_norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 913a9c021e40..efb298243177 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -758,7 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} - _pp_plan = {"embed_out": [["hidden_states"], ["logits"]]} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 2eb989d4bb3f..3a7cc49542ef 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -456,7 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] _tp_plan = {"embed_out": "colwise_rep"} - _pp_plan = {"embed_out": [["hidden_states"], ["logits"]]} + _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite/configuration_granite.py b/src/transformers/models/granite/configuration_granite.py index a45aa94058db..fc651a94e1bd 100644 --- a/src/transformers/models/granite/configuration_granite.py +++ b/src/transformers/models/granite/configuration_granite.py @@ -123,9 +123,9 @@ class GraniteConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 08a04ea6ade5..85c8e97c77ad 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -764,7 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/helium/configuration_helium.py b/src/transformers/models/helium/configuration_helium.py index 97f8eeb73ed7..7b27c6e54b69 100644 --- a/src/transformers/models/helium/configuration_helium.py +++ b/src/transformers/models/helium/configuration_helium.py @@ -96,9 +96,9 @@ class HeliumConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index e3180466bef8..86635f2d7200 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -748,7 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: HeliumConfig): super().__init__(config) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index b810c9667f69..066534f109fa 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -152,9 +152,9 @@ class LlamaConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b783ecd46d7f..a06084e82567 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -750,7 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py index 640079b28892..3a237bc7343b 100644 --- a/src/transformers/models/mistral/configuration_mistral.py +++ b/src/transformers/models/mistral/configuration_mistral.py @@ -108,9 +108,9 @@ class MistralConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 62c7493d9dea..92e555b3d768 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -751,7 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index c5da125101b0..d9b02e10fc4c 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -120,9 +120,9 @@ class MixtralConfig(PretrainedConfig): "layers.*.block_sparse_moe.experts.*.w3": "colwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a9391499a056..0835e33722e9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -967,7 +967,7 @@ def load_balancing_loss_func( class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py index e79a532c6763..ded0bf4f017c 100644 --- a/src/transformers/models/olmo/configuration_olmo.py +++ b/src/transformers/models/olmo/configuration_olmo.py @@ -116,9 +116,9 @@ class OlmoConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 09ae76962d6f..37d15475be89 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -726,7 +726,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index cfb01de77e57..222c8e179154 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -99,9 +99,9 @@ class Olmo2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 8bab410ae21c..40c912ef1a63 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -727,7 +727,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 0f8c83f4a892..bc5a9b89d501 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -110,9 +110,9 @@ class Olmo2Config(OlmoConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 8ca6cffb99e6..06e5cbec2ead 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -147,10 +147,10 @@ class PhiConfig(PretrainedConfig): "layers.*.mlp.fc2": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "embed_dropout": [["inputs_embeds"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "final_layernorm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "embed_dropout": (["inputs_embeds"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layernorm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 32d9df2d2cea..8ab41d2a0cbe 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -724,7 +724,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index e09399d6ec9f..a6b7ec9baf56 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -114,9 +114,9 @@ class Phi3Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 78969ba2173a..2595278048c8 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -821,7 +821,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2/configuration_qwen2.py b/src/transformers/models/qwen2/configuration_qwen2.py index 4e4106def5cc..16979865e4fe 100644 --- a/src/transformers/models/qwen2/configuration_qwen2.py +++ b/src/transformers/models/qwen2/configuration_qwen2.py @@ -140,9 +140,9 @@ class Qwen2Config(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 4fe14ed20c29..91eac84ffcb2 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -735,7 +735,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index e125cd7add3c..b2bf37ba0c14 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -185,9 +185,9 @@ class Qwen2_5_VLConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index eed5a5e6f276..a52b4204a662 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -161,9 +161,9 @@ class Qwen2MoeConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index f24441722ea0..8157408f42e2 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1217,7 +1217,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 9e196cd491a6..710738e39654 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -174,9 +174,9 @@ class Qwen2VLConfig(PretrainedConfig): "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/starcoder2/configuration_starcoder2.py b/src/transformers/models/starcoder2/configuration_starcoder2.py index 2b6690861817..b617a1cad842 100644 --- a/src/transformers/models/starcoder2/configuration_starcoder2.py +++ b/src/transformers/models/starcoder2/configuration_starcoder2.py @@ -144,9 +144,9 @@ class Starcoder2Config(PretrainedConfig): "layers.*.mlp.c_proj": "colwise", } base_model_pp_plan = { - "embed_tokens": [["input_ids"], ["inputs_embeds"]], - "layers": [["hidden_states", "attention_mask"], ["hidden_states"]], - "norm": [["hidden_states"], ["hidden_states"]], + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 69ecff84a226..f176d5311dd5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -747,7 +747,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": [["hidden_states"], ["logits"]]} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config)