Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin):
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
- **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a
pipeline parallel plan that enables users to place the child-module on the appropriate device.

Common attributes (present in all subclasses):

Expand Down Expand Up @@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin):
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None
_auto_class: Optional[str] = None

def __setattr__(self, key, value):
Expand Down Expand Up @@ -860,6 +863,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]

return serializable_config_dict

Expand All @@ -882,6 +888,9 @@ def to_dict(self) -> Dict[str, Any]:
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in output:
del output["base_model_pp_plan"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
Expand Down
29 changes: 29 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name


class PipelineParallel(Enum):
inputs: 0
outputs: 1


class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
Expand Down Expand Up @@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
# attribute comes from `config.base_model_pp_plan` during `post_init`.
#
# The variable names for the inputs and outputs of the specified layers can
# be indexed using the `PipelineParallel` enum as follows:
# - `_pp_plan["layers"][PipelineParallel.inputs]`
# - `_pp_plan["layers"][PipelineParallel.outputs]`
Comment on lines +1327 to +1329
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

_pp_plan = None

# This flag signal that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that they support attention interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
Expand Down Expand Up @@ -1374,6 +1391,9 @@ def post_init(self):
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan
# If current model is a base model, attach `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = self.config.base_model_pp_plan

def dequantize(self):
"""
Expand Down Expand Up @@ -5196,6 +5216,15 @@ def tplize(mod: torch.nn.Module) -> None:
# function to every submodule.
self.apply(tplize)

@property
def supports_pp_plan(self):
if self._pp_plan is not None:
return True
# Check if base model has PP plan
if getattr(self.base_model, "_pp_plan", None) is not None:
return True
return False

@property
def loss_function(self):
if hasattr(self, "_loss_function"):
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/aria/configuration_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_config_key = "text_config"

def __init__(
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):

_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = AriaTextConfig

def __init__(self, config: AriaTextConfig):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,7 @@ def _update_mamba_mask(self, attention_mask, cache_position):
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/cohere/configuration_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/cohere2/configuration_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config: Cohere2Config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/cohere2/modular_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/emu3/modeling_emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = Emu3TextConfig

def __init__(self, config):
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/gemma/modular_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig):
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig):
"layers.*.mlp.dense_h_to_4h": "colwise",
"layers.*.mlp.dense_4h_to_h": "rowwise",
}
base_model_pp_plan = {
"embed_in": (["input_ids"], ["inputs_embeds"]),
"emb_dropout": (["inputs_embeds"], ["hidden_states"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"final_layer_norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modular_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/granite/configuration_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/helium/configuration_helium.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/helium/modeling_helium.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config: HeliumConfig):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/mistral/configuration_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
Expand Down
Loading