From 07536a8deee6ccec22036fdea7e4a3bbe717d21a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:19:02 +0000 Subject: [PATCH 001/342] initial --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/openai.md | 75 ++ .../integrations/flex_attention.py | 17 +- .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 5 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/openai/__init__.py | 5 +- .../models/openai/configuration_openai.py | 291 +++-- .../openai/convert_openai_weights_to_hf.py | 606 ++++++++++ .../models/openai/modeling_openai.py | 1039 ++++------------- tests/models/openai/test_modeling_openai.py | 975 +++++++++++++--- 11 files changed, 1913 insertions(+), 1111 deletions(-) create mode 100644 docs/source/en/model_doc/openai.md create mode 100644 src/transformers/models/openai/convert_openai_weights_to_hf.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 873df4aa86ad..a0d0319c6b67 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -595,6 +595,8 @@ title: OLMoE - local: model_doc/open-llama title: Open-Llama + - local: model_doc/openai + title: openai - local: model_doc/opt title: OPT - local: model_doc/pegasus diff --git a/docs/source/en/model_doc/openai.md b/docs/source/en/model_doc/openai.md new file mode 100644 index 000000000000..211d954de749 --- /dev/null +++ b/docs/source/en/model_doc/openai.md @@ -0,0 +1,75 @@ + + +
+
+ PyTorch + Flax + FlashAttention + SDPA +
+
+ +# openai + +# openai + +## Overview + +The openai model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## OpenaiConfig + +[[autodoc]] OpenaiConfig + +## OpenaiModel + +[[autodoc]] OpenaiModel + - forward + +## OpenaiForCausalLM + +[[autodoc]] OpenaiForCausalLM + - forward + +## OpenaiForSequenceClassification + +[[autodoc]] OpenaiForSequenceClassification + - forward + +## OpenaiForQuestionAnswering + +[[autodoc]] OpenaiForQuestionAnswering + - forward + +## OpenaiForTokenClassification + +[[autodoc]] OpenaiForTokenClassification + - forward diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index afdaba5199de..98dbef5529ef 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -226,14 +226,15 @@ def flex_attention_forward( if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] - def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): - if softcap is not None: - score = softcap * torch.tanh(score / softcap) - if causal_mask is not None: - score = score + causal_mask[batch_idx][0][q_idx][kv_idx] - if head_mask is not None: - score = score + head_mask[batch_idx][head_idx][0][0] - return score + if score_mod is not None: + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) + if causal_mask is not None: + score = score + causal_mask[batch_idx][0][q_idx][kv_idx] + if head_mask is not None: + score = score + head_mask[batch_idx][head_idx][0][0] + return score enable_gqa = True num_local_query_heads = query.shape[1] diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index bebf72e45f00..3e812e3962fb 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -182,6 +182,7 @@ ("levit", "LevitConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), + ("openai", "OpenaiConfig"), ("llama4", "Llama4Config"), ("llama4_text", "Llama4TextConfig"), ("llava", "LlavaConfig"), @@ -550,6 +551,7 @@ ("levit", "LeViT"), ("lilt", "LiLT"), ("llama", "LLaMA"), + ("openai", "openai"), ("llama2", "Llama2"), ("llama3", "Llama3"), ("llama4", "Llama4"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5fe74444cb20..cc9395c521a6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -177,6 +177,7 @@ ("levit", "LevitModel"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), + ("openai", "OpenaiModel"), ("llama4", "Llama4ForConditionalGeneration"), ("llama4_text", "Llama4TextModel"), ("llava", "LlavaModel"), @@ -585,6 +586,7 @@ ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), + ("openai", "OpenaiForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), ("mamba", "MambaForCausalLM"), @@ -1099,6 +1101,7 @@ ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), ("llama", "LlamaForSequenceClassification"), + ("openai", "OpenaiForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), @@ -1189,6 +1192,7 @@ ("led", "LEDForQuestionAnswering"), ("lilt", "LiltForQuestionAnswering"), ("llama", "LlamaForQuestionAnswering"), + ("openai", "OpenaiForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ("luke", "LukeForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), @@ -1297,6 +1301,7 @@ ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), ("llama", "LlamaForTokenClassification"), + ("openai", "OpenaiForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", "LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 3c640de462d6..9dcd70c27e85 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -296,6 +296,13 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "openai", + ( + None, + "PreTrainedTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "llama4", ( diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai/__init__.py index a07b0ab669f3..f777d1981929 100644 --- a/src/transformers/models/openai/__init__.py +++ b/src/transformers/models/openai/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,9 +20,6 @@ if TYPE_CHECKING: from .configuration_openai import * from .modeling_openai import * - from .modeling_tf_openai import * - from .tokenization_openai import * - from .tokenization_openai_fast import * else: import sys diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index b4f2fae9d304..b94119974251 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -1,6 +1,10 @@ # coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,144 +17,201 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""OpenAI GPT configuration""" +"""openai model configuration""" from ...configuration_utils import PretrainedConfig -from ...utils import logging - +from ...modeling_rope_utils import rope_config_validation -logger = logging.get_logger(__name__) - -class OpenAIGPTConfig(PretrainedConfig): - """ - This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is - used to instantiate a GPT model according to the specified arguments, defining the model architecture. - Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT - [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI. +class OpenaiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OpenaiModel`]. It is used to instantiate an openai + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the openai-7B. + e.g. [meta-openai/Openai-2-7b-hf](https://huggingface.co/meta-openai/Openai-2-7b-hf) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + Args: - vocab_size (`int`, *optional*, defaults to 40478): - Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`]. - n_positions (`int`, *optional*, defaults to 512): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - n_embd (`int`, *optional*, defaults to 768): - Dimensionality of the embeddings and hidden states. - n_layer (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - n_head (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. - afn (`str` or `Callable`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"silu"` and `"gelu_new"` are supported. - resid_pdrop (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. - embd_pdrop (`int`, *optional*, defaults to 0.1): - The dropout ratio for the embeddings. - attn_pdrop (`float`, *optional*, defaults to 0.1): - The dropout ratio for the attention. - layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): - The epsilon to use in the layer normalization layers + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the openai model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OpenaiModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Openai 1 supports up to 2048 tokens, + Openai 2 up to 4096, CodeOpenai up to 16384. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - summary_type (`str`, *optional*, defaults to `"cls_index"`): - Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and - [`OpenAIGPTDoubleHeadsModel`]. - - Has to be one of the following options: - - - `"last"`: Take the last token hidden state (like XLNet). - - `"first"`: Take the first token hidden state (like BERT). - - `"mean"`: Take the mean of all tokens hidden states. - - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). - - `"attn"`: Not implemented now, use multi-head attention. - summary_use_proj (`bool`, *optional*, defaults to `True`): - Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and - [`OpenAIGPTDoubleHeadsModel`]. - - Whether or not to add a projection after the vector extraction. - summary_activation (`str`, *optional*): - Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and - [`OpenAIGPTDoubleHeadsModel`]. - - Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. - summary_proj_to_labels (`bool`, *optional*, defaults to `True`): - Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and - [`OpenAIGPTDoubleHeadsModel`]. - - Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. - summary_first_dropout (`float`, *optional*, defaults to 0.1): - Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and - [`OpenAIGPTDoubleHeadsModel`]. - - The dropout ratio to be used after the projection and activation. - - - Examples: + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'openai3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'openai3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'openai3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'openai3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads ```python - >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel + >>> from transformers import OpenaiModel, OpenaiConfig - >>> # Initializing a GPT configuration - >>> configuration = OpenAIGPTConfig() + >>> # Initializing a openai openai-7b style configuration + >>> configuration = OpenaiConfig() - >>> # Initializing a model (with random weights) from the configuration - >>> model = OpenAIGPTModel(configuration) + >>> # Initializing a model from the openai-7b style configuration + >>> model = OpenaiModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "openai-gpt" - attribute_map = { - "max_position_embeddings": "n_positions", - "hidden_size": "n_embd", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", + model_type = "openai" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `OpenaiModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), } def __init__( self, - vocab_size=40478, - n_positions=512, - n_embd=768, - n_layer=12, - n_head=12, - afn="gelu", - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, + num_hidden_layers: int = 36, + num_experts: int = 128, + vocab_size: int = 201088, + hidden_size: int = 2880, + intermediate_size: int = 2880, + head_dim: int = 64, + num_attention_heads: int = 64, + num_key_value_heads: int = 8, + sliding_window: int = 128, + rope_theta: float = 150000.0, + tie_word_embeddings= False, + hidden_act: str = "silu", + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-6, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, **kwargs, ): self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.afn = afn - self.resid_pdrop = resid_pdrop - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop - self.layer_norm_epsilon = layer_norm_epsilon + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_experts = num_experts + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act self.initializer_range = initializer_range - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_first_dropout = summary_first_dropout - self.summary_proj_to_labels = summary_proj_to_labels - super().__init__(**kwargs) - - -__all__ = ["OpenAIGPTConfig"] + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["OpenaiConfig"] diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py new file mode 100644 index 000000000000..baea0e6a6649 --- /dev/null +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -0,0 +1,606 @@ +# Copyright 2025 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import tempfile +import warnings +from typing import List + +import torch +from tokenizers import AddedToken, processors + +from transformers import GenerationConfig, LlamaTokenizer, OpenaiConfig, OpenaiForCausalLM, PreTrainedTokenizerFast +from transformers.convert_slow_tokenizer import TikTokenConverter + + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/openai/convert_openai_weights_to_hf.py \ + --input_dir /path/to/downloaded/openai/weights --model_size 1B --openai_version 3.2 --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import OpenaiForCausalLM, LlamaTokenizer + +model = OpenaiForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). + +If you want your tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor: + +```py +from tokenizers import processors +bos = "<|begin_of_text|>" +tokenizer._tokenizers.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 {bos}:1 $B:1", + special_tokens=[ + (bos, tokenizer.encode(bos)), + ], + ), + ] +) +``` +""" + +NUM_SHARDS = { + "1B": 1, + "3B": 1, + "7B": 1, + "8B": 1, + "8Bf": 1, + "7Bf": 1, + "13B": 2, + "13Bf": 2, + "34B": 4, + "30B": 4, + "65B": 8, + "70B": 8, + "70Bf": 8, + "405B": 8, + "405B-MP16": 16, +} + +CONTEXT_LENGTH_FOR_VERSION = {"Guard-3": 131072, "3.2": 131072, "3.1": 131072, "3": 8192, "2": 4096, "1": 2048} + +BOS_ADDED_TOKEN = AddedToken( + "<|begin_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True +) +EOS_ADDED_TOKEN = AddedToken( + "<|end_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True +) +EOT_ADDED_TOKEN = AddedToken( + "<|eot_id|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True +) + +DEFAULT_OPENAI_SPECIAL_TOKENS = { + "3": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)], + "3.1": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|reserved_special_token_2|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], + "3.2": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|reserved_special_token_2|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], + "Guard-3": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|reserved_special_token_2|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], +} + + +def is_openai_3(version): + return version in ["3", "3.1", "3.2", "Guard-3"] + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model( + model_path, + input_base_path, + model_size=None, + safe_serialization=True, + openai_version="1", + vocab_size=None, + num_shards=None, + instruct=False, + push_to_hub=False, +): + print("Converting the model.") + params = read_json(os.path.join(input_base_path, "params.json")) + num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards + params = params.get("model", params) + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = params.get("rope_theta", 10000.0) + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + if base > 10000.0 and not is_openai_3(openai_version): + max_position_embeddings = 16384 + else: + max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[openai_version] + + if params.get("n_kv_heads", None) is not None: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_key_value_heads_per_shard = num_key_value_heads // num_shards + key_value_dim = dims_per_head * num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_key_value_heads_per_shard = n_heads_per_shard + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + with tempfile.TemporaryDirectory() as tmp_model_path: + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if num_shards == 1: + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load( + os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu", weights_only=True + ) + else: + # Sharded + checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")]) + print("Loading in order:", checkpoint_list) + loaded = [ + torch.load(os.path.join(input_base_path, file), map_location="cpu", weights_only=True) + for file in checkpoint_list + ] + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + if num_shards == 1: + # Unsharded + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads + ), + f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wk.weight"], + n_heads=num_key_value_heads, + dim1=key_value_dim, + ), + f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], + f"model.layers.{layer_i}.input_layernorm.weight": loaded[ + f"layers.{layer_i}.attention_norm.weight" + ], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"layers.{layer_i}.ffn_norm.weight" + ], + } + else: + # Sharded + # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share + # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is + # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. + + state_dict = { + f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wq.weight"].view( + n_heads_per_shard, dims_per_head, dim + ) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(dim, dim), + n_heads=n_heads, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( + num_key_value_heads_per_shard, dims_per_head, dim + ) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( + num_key_value_heads_per_shard, dims_per_head, dim + ) + for i in range(len(loaded)) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0 + ) + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + if num_shards == 1: + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.norm.weight": loaded["norm.weight"], + "lm_head.weight": loaded["output.weight"], + } + else: + concat_dim = 0 if is_openai_3(openai_version) else 1 + state_dict = { + "model.norm.weight": loaded[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim + ), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0), + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 + multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + + if is_openai_3(openai_version): + bos_token_id = 128000 + + if instruct: + eos_token_id = [128001, 128008, 128009] + else: + eos_token_id = 128001 + else: + bos_token_id = 1 + eos_token_id = 2 + + if openai_version in ["3.1", "3.2", "Guard-3"]: + rope_scaling = { + "factor": 32.0 if openai_version == "3.2" else 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "openai3", + } + else: + rope_scaling = None + + config = OpenaiConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=num_key_value_heads, + vocab_size=vocab_size, + rope_theta=base, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=True if openai_version in ["3.2"] else False, + ) + + config.save_pretrained(tmp_model_path) + + generation_config = GenerationConfig( + do_sample=True, + temperature=0.6, + top_p=0.9, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + ) + generation_config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + print("Loading the checkpoint in a Openai model.") + model = OpenaiForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + + print("Saving in the Transformers format.") + if push_to_hub: + print("Pushing to the hub.") + model.push_to_hub(model_path, safe_serialization=safe_serialization, private=True, use_temp_dir=True) + else: + print("Saving to disk.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + + +class Openai3Converter(TikTokenConverter): + def __init__(self, vocab_file, special_tokens=None, instruct=False, openai_version="3.2", **kwargs): + super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs) + tokenizer = self.converted() + + # References for chat templates in instruct models + templates_for_version = { + "2": ("meta-openai/Openai-2-7b-chat-hf", "f5db02db724555f92da89c216ac04704f23d4590"), + "3": ("meta-openai/Meta-Openai-3-8B-Instruct", "5f0b02c75b57c5855da9ae460ce51323ea669d8a"), + "3.1": ("meta-openai/Openai-3.1-8B-Instruct", "0e9e39f249a16976918f6564b8830bc894c89659"), + "3.2": ("meta-openai/Openai-3.2-1B-Instruct", "e9f8effbab1cbdc515c11ee6e098e3d5a9f51e14"), + "Guard-3": ("meta-openai/Openai-Guard-3-1B", "acf7aafa60f0410f8f42b1fa35e077d705892029"), + } + + # Add chat_template only if instruct is True. + # Prevents a null chat_template, which triggers + # a parsing warning in the Hub. + additional_kwargs = {} + if instruct or openai_version in ["Guard-3"]: + model_id, revision = templates_for_version.get(openai_version, (None, None)) + if model_id is not None: + from transformers import AutoTokenizer + + t = AutoTokenizer.from_pretrained(model_id, revision=revision) + additional_kwargs["chat_template"] = t.chat_template + + self.converted_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", + model_input_names=["input_ids", "attention_mask"], + model_max_length=CONTEXT_LENGTH_FOR_VERSION[openai_version], + clean_up_tokenization_spaces=True, + **additional_kwargs, + ) + self.update_post_processor(self.converted_tokenizer) + # finer special_tokens_map.json + self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN + self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN if instruct else EOS_ADDED_TOKEN + + # We can't do this while building the tokenizer because we have no easy access to the bos token id + def update_post_processor(self, tokenizer): + tokenizer._tokenizer.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single="<|begin_of_text|> $A", + pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1", + special_tokens=[ + ("<|begin_of_text|>", tokenizer.convert_tokens_to_ids("<|begin_of_text|>")), + ], + ), + ] + ) + + +def write_tokenizer( + tokenizer_path, input_tokenizer_path, openai_version="2", special_tokens=None, instruct=False, push_to_hub=False +): + print("Converting the tokenizer.") + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if is_openai_3(openai_version): + tokenizer = Openai3Converter( + input_tokenizer_path, + special_tokens, + instruct, + openai_version, + ).converted_tokenizer + else: + try: + tokenizer = tokenizer_class(input_tokenizer_path) + except Exception: + raise ValueError( + "Failed to instantiate tokenizer. Please, make sure you have sentencepiece and protobuf installed." + ) + + if push_to_hub: + print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.") + tokenizer.push_to_hub(tokenizer_path, private=True, use_temp_dir=True) + else: + print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") + tokenizer.save_pretrained(tokenizer_path) + return tokenizer + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of Openai weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--model_size", + default=None, + help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Openai2 official release. For more details on Openai2, checkout the original repo: https://huggingface.co/meta-openai", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", + action="store_true", + default=False, + ) + parser.add_argument( + "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`." + ) + # Different Openai versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. + parser.add_argument( + "--openai_version", + choices=["1", "2", "3", "3.1", "3.2", "Guard-3"], + default="1", + type=str, + help="Version of the Openai model to convert. Currently supports Openai1 and Openai2. Controls the context size", + ) + parser.add_argument( + "--num_shards", + default=None, + type=int, + help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", + ) + parser.add_argument( + "--special_tokens", + default=None, + type=List[str], + help="The list of special tokens that should be added to the model.", + ) + parser.add_argument( + "--instruct", + action="store_true", + default=False, + help="Whether the model is an instruct model or not. Will affect special tokens and chat template.", + ) + args = parser.parse_args() + if args.model_size is None and args.num_shards is None: + raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`") + if args.special_tokens is None: + # no special tokens by default + args.special_tokens = DEFAULT_OPENAI_SPECIAL_TOKENS.get(str(args.openai_version), []) + + spm_path = os.path.join(args.input_dir, "tokenizer.model") + vocab_size = len( + write_tokenizer( + args.output_dir, + spm_path, + openai_version=args.openai_version, + special_tokens=args.special_tokens, + instruct=args.instruct, + push_to_hub=args.push_to_hub, + ) + ) + + if args.model_size != "tokenizer_only": + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + model_size=args.model_size, + safe_serialization=args.safe_serialization, + openai_version=args.openai_version, + vocab_size=vocab_size, + num_shards=args.num_shards, + instruct=args.instruct, + push_to_hub=args.push_to_hub, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 55aa53c40ffd..583eb487a640 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -1,6 +1,10 @@ # coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,855 +17,278 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch OpenAI GPT model.""" - -import json -import math -import os -from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch +import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import gelu_new, get_activation, silu +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer -from ...utils import ( - ModelOutput, - auto_docstring, - logging, -) -from .configuration_openai import OpenAIGPTConfig - +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from .configuration_openai import OpenaiConfig +from ..llama4.modeling_llama4 import apply_rotary_pos_emb, Llama4TextExperts +from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM, LlamaRMSNorm, repeat_kv +from ...integrations.flex_attention import flex_attention_forward logger = logging.get_logger(__name__) -def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): - """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" - import re - - import numpy as np - - if ".ckpt" in openai_checkpoint_folder_path: - openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) - - logger.info(f"Loading weights from {openai_checkpoint_folder_path}") - - with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: - names = json.load(names_handle) - with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: - shapes = json.load(shapes_handle) - offsets = np.cumsum([np.prod(shape) for shape in shapes]) - init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] - init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] - init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] - - # This was used when we had a single embedding matrix for positions and tokens - # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) - # del init_params[1] - init_params = [arr.squeeze() for arr in init_params] - - # Check that the token and position embeddings weight dimensions map those of the init parameters. - if model.tokens_embed.weight.shape != init_params[1].shape: - raise ValueError( - f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:" - f" {init_params[1].shape}" - ) - - if model.positions_embed.weight.shape != init_params[0].shape: - raise ValueError( - f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:" - f" {init_params[0].shape}" - ) - - model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) - model.positions_embed.weight.data = torch.from_numpy(init_params[0]) - names.pop(0) - # Pop position and token embedding arrays - init_params.pop(0) - init_params.pop(0) - - for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): - name = name[6:] # skip "model/" - if name[-2:] != ":0": - raise ValueError(f"Layer {name} does not end with :0") - name = name[:-2] - name = name.split("/") - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "w": - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - - # Ensure that the pointer and array have compatible shapes. - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - +class OpenaiRMSNorm(LlamaRMSNorm): + pass -ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu} +class OpenaiExperts(Llama4TextExperts): + pass - -class Attention(nn.Module): - def __init__(self, nx, n_positions, config, scale=False): +class OpenaiMLP(nn.Module): + def __init__(self, config): super().__init__() - n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implementation] - if n_state % config.n_head != 0: - raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}") - self.register_buffer( - "bias", - torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions), - persistent=False, + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.experts = OpenaiExperts(config) + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.router(hidden_states) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) + router_scores = ( + torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) - self.n_head = config.n_head - self.split_size = n_state - self.scale = scale - - self.c_attn = Conv1D(n_state * 3, nx) - self.c_proj = Conv1D(n_state, nx) - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.n_head, self.split_size // self.n_head, self.pruned_heads - ) - index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) - # Prune conv1d layers - self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) - self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) - # Update hyper params - self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) - self.n_head = self.n_head - len(heads) - self.pruned_heads = self.pruned_heads.union(heads) - - def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): - w = torch.matmul(q, k) - if self.scale: - w = w / math.sqrt(v.size(-1)) - # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights - # XD: self.b may be larger than w, so we need to crop it - b = self.bias[:, :, : w.size(-2), : w.size(-1)] - w = w * b + -1e4 * (1 - b) - - if attention_mask is not None: - # Apply the attention mask - w = w + attention_mask - - w = nn.functional.softmax(w, dim=-1) - w = self.attn_dropout(w) - - # Mask heads if we want to - if head_mask is not None: - w = w * head_mask - - outputs = [torch.matmul(w, v)] - if output_attentions: - outputs.append(w) - return outputs - - def merge_heads(self, x): - x = x.permute(0, 2, 1, 3).contiguous() - new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) - return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states - - def split_heads(self, x, k=False): - new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) - x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states - if k: - return x.permute(0, 2, 3, 1) - else: - return x.permute(0, 2, 1, 3) - - def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): - x = self.c_attn(x) - query, key, value = x.split(self.split_size, dim=2) - query = self.split_heads(query) - key = self.split_heads(key, k=True) - value = self.split_heads(value) - - attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) - a = attn_outputs[0] - - a = self.merge_heads(a) - a = self.c_proj(a) - a = self.resid_dropout(a) - - outputs = [a] + attn_outputs[1:] - return outputs # a, (attentions) - - -class MLP(nn.Module): - def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) - super().__init__() - nx = config.n_embd - self.c_fc = Conv1D(n_state, nx) - self.c_proj = Conv1D(nx, n_state) - self.act = ACT_FNS[config.afn] - self.dropout = nn.Dropout(config.resid_pdrop) + router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) + routed_in = hidden_states.repeat(self.num_experts, 1) + routed_in = routed_in * router_scores.reshape(-1, 1) + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0)) + return out, router_scores + +class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): + pass + + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = torch.cat([attn_weights, module.sink], dim=-1) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def openai_flex_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + sink = module.sink + def attention_sink(score, b, h, q_idx, kv_idx): + score = torch.cat([score, sink], dim=-1) + return score + + return flex_attention_forward( + module, + query, + key_states, + value_states, + attention_mask, + scaling=scaling, + dropout=dropout, + attention_sink=attention_sink, + score_mod=attention_sink, + **kwargs, + ) - def forward(self, x): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - return self.dropout(h2) +class OpenaiAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" -class Block(nn.Module): - def __init__(self, n_positions, config, scale=False): + def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__() - nx = config.n_embd - self.attn = Attention(nx, n_positions, config, scale) - self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) - self.mlp = MLP(4 * nx, config) - self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) - - def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): - attn_outputs = self.attn( - x, - attention_mask=attention_mask, - head_mask=head_mask, - output_attentions=output_attentions, + self.sinks = torch.empty( + config.num_attention_heads, device=device, dtype=torch.bfloat16 ) - a = attn_outputs[0] - - n = self.ln_1(x + a) - m = self.mlp(n) - h = self.ln_2(n + m) - - outputs = [h] + attn_outputs[1:] - return outputs - - -# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT -class OpenAIGPTSequenceSummary(nn.Module): - r""" - Compute a single vector summary of a sequence hidden states. - - Args: - config ([`OpenAIGPTConfig`]): - The config used by the model. Relevant arguments in the config class of the model are (refer to the actual - config class of your model for the default values it uses): - - - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: - - - `"last"` -- Take the last token hidden state (like XLNet) - - `"first"` -- Take the first token hidden state (like Bert) - - `"mean"` -- Take the mean of all tokens hidden states - - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) - - `"attn"` -- Not implemented now, use multi-head attention - - - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. - - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes - (otherwise to `config.hidden_size`). - - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, - another string or `None` will add no activation. - - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. - - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. - """ - - def __init__(self, config: OpenAIGPTConfig): - super().__init__() - - self.summary_type = getattr(config, "summary_type", "last") - if self.summary_type == "attn": - # We should use a standard multi-head attention module with absolute positional embedding for that. - # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 - # We can probably just use the multi-head attention module of PyTorch >=1.1.0 - raise NotImplementedError - - self.summary = nn.Identity() - if hasattr(config, "summary_use_proj") and config.summary_use_proj: - if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: - num_classes = config.num_labels - else: - num_classes = config.hidden_size - self.summary = nn.Linear(config.hidden_size, num_classes) - - activation_string = getattr(config, "summary_activation", None) - self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() - - self.first_dropout = nn.Identity() - if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: - self.first_dropout = nn.Dropout(config.summary_first_dropout) - - self.last_dropout = nn.Identity() - if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: - self.last_dropout = nn.Dropout(config.summary_last_dropout) def forward( - self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None - ) -> torch.FloatTensor: - """ - Compute a single vector summary of a sequence hidden states. - - Args: - hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): - The hidden states of the last layer. - cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): - Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. - - Returns: - `torch.FloatTensor`: The summary of the sequence hidden states. - """ - if self.summary_type == "last": - output = hidden_states[:, -1] - elif self.summary_type == "first": - output = hidden_states[:, 0] - elif self.summary_type == "mean": - output = hidden_states.mean(dim=1) - elif self.summary_type == "cls_index": - if cls_index is None: - cls_index = torch.full_like( - hidden_states[..., :1, :], - hidden_states.shape[-2] - 1, - dtype=torch.long, + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: - cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) - cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) - # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states - output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) - elif self.summary_type == "attn": - raise NotImplementedError - - output = self.first_dropout(output) - output = self.summary(output) - output = self.activation(output) - output = self.last_dropout(output) - - return output - - -@auto_docstring -class OpenAIGPTPreTrainedModel(PreTrainedModel): - config_class = OpenAIGPTConfig - load_tf_weights = load_tf_weights_in_openai_gpt - base_model_prefix = "transformer" - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -@dataclass -class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): - Multiple choice classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): - Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - mc_loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - mc_logits: Optional[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -@auto_docstring -class OpenAIGPTModel(OpenAIGPTPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) - self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) - - self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.tokens_embed - - def set_input_embeddings(self, new_embeddings): - self.tokens_embed = new_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - for layer, heads in heads_to_prune.items(): - self.h[layer].attn.prune_heads(heads) - - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if position_ids is None: - # Code is different from when we had a single embedding matrix from position and token embeddings - position_ids = self.position_ids[None, : input_shape[-1]] - - # Attention mask. - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # Prepare head mask if needed - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.tokens_embed(input_ids) - position_embeds = self.positions_embed(position_ids) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - token_type_embeds = self.tokens_embed(token_type_ids) - else: - token_type_embeds = 0 - hidden_states = inputs_embeds + position_embeds + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions) - hidden_states = outputs[0] - if output_attentions: - all_attentions = all_attentions + (outputs[1],) - - hidden_states = hidden_states.view(*output_shape) - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights -@auto_docstring( - custom_intro=""" - OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """ -) -class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.transformer = OpenAIGPTModel(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Openai +class OpenaiDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OpenaiConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size - def get_output_embeddings(self): - return self.lm_head + self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + self.mlp = OpenaiMLP(config) + self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - @auto_docstring def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple[torch.Tensor], CausalLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, - token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, + past_key_value=past_key_value, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Flatten the tokens - loss = self.loss_function( - lm_logits, - labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutput( - loss=loss, - logits=lm_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, ) + hidden_states = residual + hidden_states - def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: - # Overwritten -- old model with reduced inputs - return {"input_ids": input_ids} - - -@auto_docstring( - custom_intro=""" - OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for - RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the - input embeddings, the classification head takes as input the input of a specified classification token index in the - input sequence). - """ -) -class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - - config.num_labels = 1 - self.transformer = OpenAIGPTModel(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.multiple_choice_head = OpenAIGPTSequenceSummary(config) - - # Initialize weights and apply final processing - self.post_init() + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]: - r""" - mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are - ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) - - Examples: - - ```python - >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") - >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") - >>> tokenizer.add_special_tokens( - ... {"cls_token": "[CLS]"} - ... ) # Add a [CLS] to the vocabulary (we should train it also!) - >>> model.resize_token_embeddings(len(tokenizer)) - - >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] - >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices - >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1 - - >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) - >>> lm_logits = outputs.logits - >>> mc_logits = outputs.mc_logits - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) - - lm_loss, mc_loss = None, None - if mc_labels is not None: - loss_fct = CrossEntropyLoss() - mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) - if labels is not None: - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits, mc_logits) + transformer_outputs[1:] - if mc_loss is not None: - output = (mc_loss,) + output - return ((lm_loss,) + output) if lm_loss is not None else output - - return OpenAIGPTDoubleHeadsModelOutput( - loss=lm_loss, - mc_loss=mc_loss, - logits=lm_logits, - mc_logits=mc_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) + return outputs -@auto_docstring( - custom_intro=""" - The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer). - [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal - models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the - last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding - token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since - it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take - the last value in each row of the batch). - """ -) -class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.transformer = OpenAIGPTModel(config) - self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Openai +class OpenaiPreTrainedModel(PreTrainedModel): + config_class = OpenaiConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OpenaiDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True - # Initialize weights and apply final processing - self.post_init() + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, OpenaiRMSNorm): + module.weight.data.fill_(1.0) - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) +class OpenaiModel(LlamaModel): + pass - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size, sequence_length = input_ids.shape[:2] - else: - batch_size, sequence_length = inputs_embeds.shape[:2] - - # Ensure the batch size is > 1 if there is no padding. - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=pooled_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) +class OpenaiForCausalLM(LlamaForCausalLM): + pass __all__ = [ - "OpenAIGPTDoubleHeadsModel", - "OpenAIGPTForSequenceClassification", - "OpenAIGPTLMHeadModel", - "OpenAIGPTModel", - "OpenAIGPTPreTrainedModel", - "load_tf_weights_in_openai_gpt", + "OpenaiForCausalLM", + "OpenaiModel", + "OpenaiPreTrainedModel", ] diff --git a/tests/models/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py index bba4ad8660fb..68516a691500 100644 --- a/tests/models/openai/test_modeling_openai.py +++ b/tests/models/openai/test_modeling_openai.py @@ -1,4 +1,4 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,12 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Testing suite for the PyTorch openai model.""" import unittest -from transformers import is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from packaging import version +from parameterized import parameterized + +from transformers import AutoTokenizer, OpenaiConfig, StaticCache, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + Expectations, + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -28,22 +40,25 @@ import torch from transformers import ( - OpenAIGPTConfig, - OpenAIGPTDoubleHeadsModel, - OpenAIGPTForSequenceClassification, - OpenAIGPTLMHeadModel, - OpenAIGPTModel, + LlamaTokenizer, + OpenaiForCausalLM, + OpenaiForQuestionAnswering, + OpenaiForSequenceClassification, + OpenaiForTokenClassification, + OpenaiModel, ) + from transformers.models.openai.modeling_openai import OpenaiRotaryEmbedding -class OpenAIGPTModelTester: +class OpenaiModelTester: def __init__( self, parent, batch_size=13, seq_length=7, is_training=True, - use_token_type_ids=True, + use_input_mask=True, + use_token_type_ids=False, use_labels=True, vocab_size=99, hidden_size=32, @@ -59,12 +74,14 @@ def __init__( initializer_range=0.02, num_labels=3, num_choices=4, + pad_token_id=0, scope=None, ): self.parent = parent self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training + self.use_input_mask = use_input_mask self.use_token_type_ids = use_token_type_ids self.use_labels = use_labels self.vocab_size = vocab_size @@ -81,12 +98,16 @@ def __init__( self.initializer_range = initializer_range self.num_labels = num_labels self.num_choices = num_choices + self.pad_token_id = pad_token_id self.scope = scope - self.pad_token_id = self.vocab_size - 1 def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + token_type_ids = None if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) @@ -99,213 +120,811 @@ def prepare_config_and_inputs(self): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = OpenAIGPTConfig( - vocab_size=self.vocab_size, - n_embd=self.hidden_size, - n_layer=self.num_hidden_layers, - n_head=self.num_attention_heads, - # intermediate_size=self.intermediate_size, - # hidden_act=self.hidden_act, - # hidden_dropout_prob=self.hidden_dropout_prob, - # attention_probs_dropout_prob=self.attention_probs_dropout_prob, - n_positions=self.max_position_embeddings, - # type_vocab_size=self.type_vocab_size, - # initializer_range=self.initializer_range - pad_token_id=self.pad_token_id, - ) + config = self.get_config() - head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - return ( - config, - input_ids, - head_mask, - token_type_ids, - sequence_labels, - token_labels, - choice_labels, + def get_config(self): + return OpenaiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, ) - def create_and_check_openai_gpt_model(self, config, input_ids, head_mask, token_type_ids, *args): - model = OpenAIGPTModel(config=config) + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = OpenaiModel(config=config) model.to(torch_device) model.eval() - - result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) - result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids, attention_mask=input_mask) result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): - model = OpenAIGPTLMHeadModel(config) - model.to(torch_device) - model.eval() - - result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) - self.parent.assertEqual(result.loss.shape, ()) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - - def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): - model = OpenAIGPTDoubleHeadsModel(config) - model.to(torch_device) - model.eval() - - result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) - self.parent.assertEqual(result.loss.shape, ()) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - - def create_and_check_openai_gpt_for_sequence_classification( - self, config, input_ids, head_mask, token_type_ids, *args - ): - config.num_labels = self.num_labels - model = OpenAIGPTForSequenceClassification(config) - model.to(torch_device) - model.eval() - - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( config, input_ids, - head_mask, token_type_ids, + input_mask, sequence_labels, token_labels, choice_labels, ) = config_and_inputs - inputs_dict = { - "input_ids": input_ids, - "token_type_ids": token_type_ids, - "head_mask": head_mask, - } - + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} return config, inputs_dict @require_torch -class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class OpenaiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification) + ( + OpenaiModel, + OpenaiForCausalLM, + OpenaiForSequenceClassification, + OpenaiForQuestionAnswering, + OpenaiForTokenClassification, + ) if is_torch_available() else () ) - pipeline_model_mapping = ( - { - "feature-extraction": OpenAIGPTModel, - "text-classification": OpenAIGPTForSequenceClassification, - "text-generation": OpenAIGPTLMHeadModel, - "zero-shot": OpenAIGPTForSequenceClassification, - } - if is_torch_available() - else {} - ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez - # TODO: Fix the failed tests - def is_pipeline_test_to_skip( - self, - pipeline_test_case_name, - config_class, - model_architecture, - tokenizer_name, - image_processor_name, - feature_extractor_name, - processor_name, - ): - if pipeline_test_case_name == "ZeroShotClassificationPipelineTests": - # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers. - # `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a - # tiny config could not be created. - return True - - return False - - # special case for DoubleHeads model - def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): - inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) - - if return_labels: - if model_class.__name__ == "OpenAIGPTDoubleHeadsModel": - inputs_dict["labels"] = torch.zeros( - (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), - dtype=torch.long, - device=torch_device, - ) - inputs_dict["input_ids"] = inputs_dict["labels"] - inputs_dict["token_type_ids"] = inputs_dict["labels"] - inputs_dict["mc_token_ids"] = torch.zeros( - (self.model_tester.batch_size, self.model_tester.num_choices), - dtype=torch.long, - device=torch_device, - ) - inputs_dict["mc_labels"] = torch.zeros( - self.model_tester.batch_size, dtype=torch.long, device=torch_device - ) - return inputs_dict + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = OpenaiForCausalLM if is_torch_available() else None def setUp(self): - self.model_tester = OpenAIGPTModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) + self.model_tester = OpenaiModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenaiConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() - def test_openai_gpt_model(self): + def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_openai_gpt_model(*config_and_inputs) + self.model_tester.create_and_check_model(*config_and_inputs) - def test_openai_gpt_lm_head_model(self): + def test_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_openai_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = OpenaiForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) - def test_openai_gpt_double_lm_head_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs) + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = OpenaiModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = OpenaiModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn( + 1, dtype=torch.float32, device=torch_device + ) # used exclusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) + + # Sanity check original RoPE + original_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + + # Sanity check Yarn RoPE scaling + # Scaling should be over the entire input + config.rope_scaling = {"type": "yarn", "factor": scaling_factor} + yarn_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + def test_model_loading_old_rope_configs(self): + def _reinitialize_config(base_config, new_kwargs): + # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation + # steps. + base_config_dict = base_config.to_dict() + new_config = OpenaiConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) + return new_config + + # from untouched config -> ✅ + base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() + original_model = OpenaiForCausalLM(base_config).to(torch_device) + original_model(**model_inputs) + + # from a config with the expected rope configuration -> ✅ + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC + config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) + config = _reinitialize_config( + base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} + ) + self.assertTrue(config.rope_scaling["type"] == "linear") + self.assertTrue(config.rope_scaling["rope_type"] == "linear") + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("factor field", logs.output[0]) + + # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config( + base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} + ) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("Unrecognized keys", logs.output[0]) + + # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception + with self.assertRaises(KeyError): + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" + + +@require_torch_accelerator +class OpenaiIntegrationTest(unittest.TestCase): + def tearDown(self): + # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves + # some memory allocated in the cache, which means some object is not being released properly. This causes some + # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. + # Investigate the root cause. + cleanup(torch_device, gc_collect=False) - def test_openai_gpt_classification_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs) + @slow + @require_read_token + def test_openai_3_1_hard(self): + """ + An integration test for openai 3.1. It tests against a long output to ensure the subtle numerical differences + from openai 3.1.'s RoPE can be detected + """ + # diff on `EXPECTED_TEXT`: + # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. + EXPECTED_TEXT = ( + "Tell me about the french revolution. The french revolution was a period of radical political and social " + "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " + "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " + "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " + "assembly that had not met since 1614. The Third Estate, which represented the common people, " + "demanded greater representation and eventually broke away to form the National Assembly. This marked " + "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" + ) + + tokenizer = AutoTokenizer.from_pretrained("meta-openai/Meta-Openai-3.1-8B-Instruct") + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Meta-Openai-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 + ) + input_text = ["Tell me about the french revolution."] + model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(generated_text, EXPECTED_TEXT) @slow - def test_model_from_pretrained(self): - model_name = "openai-community/openai-gpt" - model = OpenAIGPTModel.from_pretrained(model_name) - self.assertIsNotNone(model) + @require_read_token + def test_model_7b_logits_bf16(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + # Expected mean on dim = -1 + + # fmt: off + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), + ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), + ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) + }) + # fmt: on + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) -@require_torch -class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase): @slow - def test_lm_generate_openai_gpt(self): - model = OpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt") - model.to(torch_device) - input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is - expected_output_ids = [ - 481, - 4735, - 544, - 246, - 963, - 870, - 762, - 239, - 244, - 40477, - 244, - 249, - 719, - 881, - 487, - 544, - 240, - 244, - 603, - 481, - ] # the president is a very good man. " \n " i\'m sure he is, " said the - - output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + @require_read_token + def test_model_7b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + + # fmt: off + # Expected mean on dim = -1 + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), + ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), + ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), + ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) + }) + # fmt: on + + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_torch_accelerator + @require_read_token + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right") + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + @slow + @require_read_token + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + openai_models = { + "meta-openai/Openai-3.2-1B": [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all " + "observers, regardless of their location, and 2) the laws of physics are the same for all observers" + ], + } + + for openai_model_ckp, EXPECTED_TEXT_COMPLETION in openai_models.items(): + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(openai_model_ckp, pad_token="", padding_side="right") + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = OpenaiForCausalLM.from_pretrained( + openai_model_ckp, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": device, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + model_name = "TinyOpenai/TinyOpenai-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = LlamaTokenizer.from_pretrained(model_name) + self.model = OpenaiForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, + attention_mask=mask_1b, + position_ids=position_ids_1b, + past_key_values=past_key_values_a, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + def test_stacked_causal_mask_static_cache(self): + """same as above but with StaticCache""" + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + padded_attention_mask = torch.nn.functional.pad( + input=mask_shared_prefix, + pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, + attention_mask=padded_attention_mask, + position_ids=position_ids_shared_prefix, + cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), + past_key_values=past_key_values, + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask_static_cache(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + # forward run for the first part of input + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + padded_mask_1a = torch.nn.functional.pad( + input=mask_1a, + pad=(0, max_cache_len - mask_1a.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + _ = self.model.forward( + input_1a, + attention_mask=padded_mask_1a, + position_ids=position_ids_1a, + cache_position=torch.arange(part_a, device=torch_device), + past_key_values=past_key_values, + ) + + # forward run for the second part of input + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + + padded_mask_1b = torch.nn.functional.pad( + input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 + ) + + outs_1b = self.model.forward( + input_1b, + attention_mask=padded_mask_1b, + position_ids=position_ids_1b, + cache_position=torch.arange( + part_a, + input_ids_shared_prefix.shape[-1], + device=torch_device, + ), + past_key_values=past_key_values, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) From a5db3468ec0dd06363eddedeeee9207ce797d094 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:21:42 +0000 Subject: [PATCH 002/342] push simplifcations --- .../models/openai/modeling_openai.py | 64 ++----------------- 1 file changed, 7 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 583eb487a640..6cb6310bd3a7 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -116,8 +116,6 @@ def openai_flex_attention_forward( dropout: float = 0.0, **kwargs, ): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) sink = module.sink def attention_sink(score, b, h, q_idx, kv_idx): score = torch.cat([score, sink], dim=-1) @@ -126,8 +124,8 @@ def attention_sink(score, b, h, q_idx, kv_idx): return flex_attention_forward( module, query, - key_states, - value_states, + key, + value, attention_mask, scaling=scaling, dropout=dropout, @@ -136,65 +134,17 @@ def attention_sink(score, b, h, q_idx, kv_idx): **kwargs, ) +ALL_ATTENTION_FUNCTIONS.register( + "openai_flex_attention", openai_flex_attention_forward +) + class OpenaiAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__() - self.sinks = torch.empty( - config.num_attention_heads, device=device, dtype=torch.bfloat16 - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + self.sinks = torch.empty(config.num_attention_heads) # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Openai From ef7b00b20f580f8924fb54cc71688396b636a5de Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:24:45 +0000 Subject: [PATCH 003/342] nits what I can see --- .../models/openai/configuration_openai.py | 125 +----------------- .../models/openai/modeling_openai.py | 64 +-------- 2 files changed, 8 insertions(+), 181 deletions(-) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index b94119974251..5fcb6c0c9cae 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -24,122 +24,7 @@ class OpenaiConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`OpenaiModel`]. It is used to instantiate an openai - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the openai-7B. - e.g. [meta-openai/Openai-2-7b-hf](https://huggingface.co/meta-openai/Openai-2-7b-hf) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the openai model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`OpenaiModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Openai 1 supports up to 2048 tokens, - Openai 2 up to 4096, CodeOpenai up to 16384. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'openai3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'openai3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'openai3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'openai3'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - head_dim (`int`, *optional*): - The attention head dimension. If None, it will default to hidden_size // num_attention_heads - - ```python - >>> from transformers import OpenaiModel, OpenaiConfig - - >>> # Initializing a openai openai-7b style configuration - >>> configuration = OpenaiConfig() - - >>> # Initializing a model from the openai-7b style configuration - >>> model = OpenaiModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - + model_type = "openai" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `OpenaiModel` @@ -148,9 +33,9 @@ class OpenaiConfig(PretrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", + "layers.*.mlp.gate_up_proj": "local_packed_rowwise", + "layers.*.mlp.down_proj": "local_colwise", + "layers.*.mlp": "local", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), @@ -177,6 +62,8 @@ def __init__( pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, + rope_scaling=None, + attention_dropout: float = 0.0, **kwargs, ): self.vocab_size = vocab_size diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 6cb6310bd3a7..53b986b00a95 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -40,7 +40,7 @@ from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_openai import OpenaiConfig from ..llama4.modeling_llama4 import apply_rotary_pos_emb, Llama4TextExperts -from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM, LlamaRMSNorm, repeat_kv +from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM, LlamaRMSNorm, repeat_kv, LlamaPreTrainedModel from ...integrations.flex_attention import flex_attention_forward logger = logging.get_logger(__name__) @@ -152,83 +152,23 @@ class OpenaiDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) - self.mlp = OpenaiMLP(config) self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - -@auto_docstring -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Openai -class OpenaiPreTrainedModel(PreTrainedModel): +class OpenaiPreTrainedModel(LlamaPreTrainedModel): config_class = OpenaiConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OpenaiDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, OpenaiRMSNorm): - module.weight.data.fill_(1.0) class OpenaiModel(LlamaModel): pass From 1cd0c6a132bb9e795cb4c05f1bfbe00e5e4d46d6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:32:07 +0000 Subject: [PATCH 004/342] add key mapping --- .../openai/convert_openai_weights_to_hf.py | 957 ++++++++---------- 1 file changed, 436 insertions(+), 521 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index baea0e6a6649..18955cc6e1f0 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -1,4 +1,5 @@ -# Copyright 2025 EleutherAI and The HuggingFace Inc. team. All rights reserved. + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,595 +12,509 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import argparse import gc import json +import math import os -import tempfile -import warnings -from typing import List +from typing import List, Optional +import regex as re import torch -from tokenizers import AddedToken, processors - -from transformers import GenerationConfig, LlamaTokenizer, OpenaiConfig, OpenaiForCausalLM, PreTrainedTokenizerFast -from transformers.convert_slow_tokenizer import TikTokenConverter - - -try: - from transformers import LlamaTokenizerFast -except ImportError as e: - warnings.warn(e) - warnings.warn( - "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" - ) - LlamaTokenizerFast = None - -""" -Sample usage: - -``` -python src/transformers/models/openai/convert_openai_weights_to_hf.py \ - --input_dir /path/to/downloaded/openai/weights --model_size 1B --openai_version 3.2 --output_dir /output/path -``` - -Thereafter, models can be loaded via: - -```py -from transformers import OpenaiForCausalLM, LlamaTokenizer - -model = OpenaiForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") -``` - -Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions -come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). - -If you want your tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor: - -```py -from tokenizers import processors -bos = "<|begin_of_text|>" -tokenizer._tokenizers.post_processor = processors.Sequence( - [ - processors.ByteLevel(trim_offsets=False), - processors.TemplateProcessing( - single=f"{bos}:0 $A:0", - pair=f"{bos}:0 $A:0 {bos}:1 $B:1", - special_tokens=[ - (bos, tokenizer.encode(bos)), - ], - ), - ] -) -``` -""" - -NUM_SHARDS = { - "1B": 1, - "3B": 1, - "7B": 1, - "8B": 1, - "8Bf": 1, - "7Bf": 1, - "13B": 2, - "13Bf": 2, - "34B": 4, - "30B": 4, - "65B": 8, - "70B": 8, - "70Bf": 8, - "405B": 8, - "405B-MP16": 16, -} - -CONTEXT_LENGTH_FOR_VERSION = {"Guard-3": 131072, "3.2": 131072, "3.1": 131072, "3": 8192, "2": 4096, "1": 2048} - -BOS_ADDED_TOKEN = AddedToken( - "<|begin_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True -) -EOS_ADDED_TOKEN = AddedToken( - "<|end_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True -) -EOT_ADDED_TOKEN = AddedToken( - "<|eot_id|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True +import torch.nn.functional as F + +from transformers import ( + GenerationConfig, + MllamaConfig, + MllamaForConditionalGeneration, + MllamaImageProcessor, + PreTrainedTokenizerFast, ) - -DEFAULT_OPENAI_SPECIAL_TOKENS = { - "3": [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] - + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)], - "3.1": [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|reserved_special_token_2|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], - "3.2": [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|reserved_special_token_2|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], - "Guard-3": [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|reserved_special_token_2|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)], +from transformers.convert_slow_tokenizer import TikTokenConverter +# fmt: off +# If a weight needs to be split in two or more keys, use `|` to indicate it. ex: +# r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"norm.weight": r"norm.weight", + r"unembedding.weight": r"lm_head.weight", + r"embedding": r"embed_tokens", + r"rope.freqs": None, # meaning we skip it and don't want it + # special key, wqkv needs to be split afterwards + r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(q|k|v)_proj", + r"block.(\d+).attn.out": r"layers.\1.self_attn.\2_proj", + r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", + r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", + + r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.gate_up_proj.weight", + r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.gate_up_proj.bias", + r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.down_proj.weight", + r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.down_proj.bias", + r"block.(\d+).mlp.norm": r"layers.\1.post_attention_layernorm.weight", + r"block.(\d+).mlp.gate": r"layers.\1.mlp.router.weight", } +# fmt: on - -def is_openai_3(version): - return version in ["3", "3.1", "3.2", "Guard-3"] +CONTEXT_LENGTH = 131072 -def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): - return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) +def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict -def read_json(path): - with open(path, "r") as f: - return json.load(f) - - -def write_json(text, path): - with open(path, "w") as f: - json.dump(text, f) +def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3): + hidden_dim = 4 * int(2 * hidden_dim / 3) + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim def write_model( model_path, input_base_path, - model_size=None, + num_shards, safe_serialization=True, - openai_version="1", - vocab_size=None, - num_shards=None, instruct=False, - push_to_hub=False, ): - print("Converting the model.") - params = read_json(os.path.join(input_base_path, "params.json")) - num_shards = NUM_SHARDS[model_size] if num_shards is None else num_shards + os.makedirs(model_path, exist_ok=True) + + with open(os.path.join(input_base_path, "params.json"), "r") as f: + params = json.load(f) + params = params.get("model", params) - n_layers = params["n_layers"] - n_heads = params["n_heads"] - n_heads_per_shard = n_heads // num_shards - dim = params["dim"] - dims_per_head = dim // n_heads - base = params.get("rope_theta", 10000.0) - inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - if base > 10000.0 and not is_openai_3(openai_version): - max_position_embeddings = 16384 - else: - max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[openai_version] + torch_dtype = "bfloat16" + + # ------------------------------------------------------------ + # Text model params and config + # ------------------------------------------------------------ + + # params from config + text_vocab_size = params["vocab_size"] + text_num_layers = params["n_layers"] + text_dim = params["dim"] + text_num_heads = params["n_heads"] + text_rms_norm_eps = params["norm_eps"] + text_rope_theta = params["rope_theta"] + cross_attention_num_layers = params["vision_num_block"] + + # some constants from original code + rope_scaling = { + "rope_type": "llama3", + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + } + max_position_embeddings = CONTEXT_LENGTH + + # compute additional params for weight conversion + text_num_heads_per_shard = text_num_heads // num_shards + text_dim_per_head = text_dim // text_num_heads + text_intermediate_size = compute_intermediate_size(text_dim, multiple_of=params["multiple_of"]) if params.get("n_kv_heads", None) is not None: - num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_key_value_heads_per_shard = num_key_value_heads // num_shards - key_value_dim = dims_per_head * num_key_value_heads + text_num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + text_num_key_value_heads_per_shard = text_num_key_value_heads // num_shards + text_key_value_dim = text_dim_per_head * text_num_key_value_heads else: # compatibility with other checkpoints - num_key_value_heads = n_heads - num_key_value_heads_per_shard = n_heads_per_shard - key_value_dim = dim - - # permute for sliced rotary - def permute(w, n_heads, dim1=dim, dim2=dim): - return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - with tempfile.TemporaryDirectory() as tmp_model_path: - print(f"Fetching all parameters from the checkpoint at {input_base_path}.") - # Load weights - if num_shards == 1: - # Not sharded - # (The sharded implementation would also work, but this is simpler.) - loaded = torch.load( - os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu", weights_only=True - ) - else: - # Sharded - checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")]) - print("Loading in order:", checkpoint_list) - loaded = [ - torch.load(os.path.join(input_base_path, file), map_location="cpu", weights_only=True) - for file in checkpoint_list - ] - param_count = 0 - index_dict = {"weight_map": {}} - for layer_i in range(n_layers): - filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - state_dict = { - f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( - loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads - ), - f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( - loaded[f"layers.{layer_i}.attention.wk.weight"], - n_heads=num_key_value_heads, - dim1=key_value_dim, - ), - f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], - f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], - f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], - f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], - f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], - f"model.layers.{layer_i}.input_layernorm.weight": loaded[ - f"layers.{layer_i}.attention_norm.weight" - ], - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ - f"layers.{layer_i}.ffn_norm.weight" - ], - } - else: - # Sharded - # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share - # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is - # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. - - state_dict = { - f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wq.weight"].view( - n_heads_per_shard, dims_per_head, dim - ) - for i in range(len(loaded)) - ], - dim=0, - ).reshape(dim, dim), - n_heads=n_heads, - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( - num_key_value_heads_per_shard, dims_per_head, dim - ) - for i in range(len(loaded)) - ], - dim=0, - ).reshape(key_value_dim, dim), - num_key_value_heads, - key_value_dim, - dim, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( - num_key_value_heads_per_shard, dims_per_head, dim - ) - for i in range(len(loaded)) - ], - dim=0, - ).reshape(key_value_dim, dim) - - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(len(loaded))], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(len(loaded))], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(len(loaded))], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(len(loaded))], dim=0 - ) - - state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - state_dict = { - "model.embed_tokens.weight": loaded["tok_embeddings.weight"], - "model.norm.weight": loaded["norm.weight"], - "lm_head.weight": loaded["output.weight"], - } + text_num_key_value_heads = text_num_heads + text_num_key_value_heads_per_shard = text_num_heads_per_shard + text_key_value_dim = text_dim + + # cross-attention layers: 20 for 90B, 8 for 11B + cross_attention_frequency = math.ceil(text_num_layers / cross_attention_num_layers) + text_num_total_layers = text_num_layers + cross_attention_num_layers + block_shift = list( + range(cross_attention_frequency - 1, text_num_total_layers, cross_attention_frequency + 1) + ) + self_attention_layers_shift = [k for k in range(text_num_total_layers) if k not in block_shift] + + bos_token_id = 128000 + eos_token_id = [128001, 128008, 128009] if instruct else 128001 + pad_token_id = 128004 + + text_config = MllamaTextConfig( + num_attention_heads=text_num_heads, + vocab_size=text_vocab_size, + hidden_size=text_dim, + rms_norm_eps=text_rms_norm_eps, + rope_theta=text_rope_theta, + num_hidden_layers=text_num_total_layers, + block=block_shift, + intermediate_size=text_intermediate_size, + max_position_embeddings=max_position_embeddings, + rope_scaling=rope_scaling, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=False, # Constant set to False + torch_dtype=torch_dtype, + ) + + # ------------------------------------------------------------ + # Vision model params and config + # ------------------------------------------------------------ + + # params from config + vision_tile_size = params["vision_chunk_size"] + vision_max_num_tiles = params["vision_max_num_chunks"] + + # some constants from original code + vision_patch_size = 14 + vision_num_channels = 3 + vision_num_layers = 32 + vision_num_layers_global = 8 + vision_dim = 1280 + vision_num_heads = 16 + vision_intermediate_layers_indices = [3, 7, 15, 23, 30] + + # compute additional params for weight conversion + vision_dim_per_head = vision_dim // vision_num_heads + vision_num_heads_per_shard = vision_num_heads // num_shards + vision_intermediate_size = vision_dim * 4 + vision_supported_aspect_ratios = get_all_supported_aspect_ratios(vision_max_num_tiles) + + vision_config = MllamaVisionConfig( + hidden_size=vision_dim, + patch_size=vision_patch_size, + num_channels=vision_num_channels, + intermediate_size=vision_intermediate_size, + num_hidden_layers=vision_num_layers, + num_attention_heads=vision_num_heads, + num_global_layers=vision_num_layers_global, + intermediate_layers_indices=vision_intermediate_layers_indices, + image_size=vision_tile_size, + max_num_tiles=vision_max_num_tiles, + supported_aspect_ratios=vision_supported_aspect_ratios, + torch_dtype=torch_dtype, + ) + + # save config + config = MllamaConfig(vision_config=vision_config, text_config=text_config, torch_dtype=torch_dtype) + config.architectures = ["MllamaForConditionalGeneration"] + config.save_pretrained(model_path) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {input_base_path}...") + if num_shards == 1: + if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")): + path = os.path.join(input_base_path, "consolidated.00.pth") else: - concat_dim = 0 if is_openai_3(openai_version) else 1 - state_dict = { - "model.norm.weight": loaded[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat( - [loaded[i]["tok_embeddings.weight"] for i in range(len(loaded))], dim=concat_dim - ), - "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(len(loaded))], dim=0), - } - - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - # Write configs - index_dict["metadata"] = {"total_size": param_count * 2} - write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) - ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 - multiple_of = params["multiple_of"] if "multiple_of" in params else 256 - - if is_openai_3(openai_version): - bos_token_id = 128000 - - if instruct: - eos_token_id = [128001, 128008, 128009] + path = os.path.join(input_base_path, "consolidated.pth") + loaded = [torch.load(path, map_location="cpu", mmap=True, weights_only=True)] + else: + loaded = [ + torch.load( + os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), + map_location="cpu", + mmap=True, + weights_only=True, + ) + for i in range(num_shards) + ] + + print("Converting ..") + all_keys = list(loaded[0].keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + + # In the original model, self-attention layers and cross-attention layers are different lists of layers. + # In the converted model, they are merged into one list with corresponding index shift to preserve the order. + if ("cross_attention" in key or "layers" in key) and "language_model" in new_key: + shift = block_shift if "cross_attention" in key else self_attention_layers_shift + new_key = re.sub(r"layers.(\d+).", lambda _match: f"layers.{shift[int(_match.groups()[0])]}.", new_key) + + current_parameter = [chunk.pop(key).contiguous().clone() for chunk in loaded] + if not is_param_different_across_shards(new_key): + current_parameter = current_parameter[0] + + concat_dim = get_concat_dim(new_key) + + # Post-process the current_parameter. + if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key: + if "q_proj" in new_key: + param_num_heads = text_num_heads + param_num_head_per_shard = text_num_heads_per_shard + param_dim = text_dim else: - eos_token_id = 128001 - else: - bos_token_id = 1 - eos_token_id = 2 - - if openai_version in ["3.1", "3.2", "Guard-3"]: - rope_scaling = { - "factor": 32.0 if openai_version == "3.2" else 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "openai3", - } - else: - rope_scaling = None - - config = OpenaiConfig( - hidden_size=dim, - intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), - num_attention_heads=params["n_heads"], - num_hidden_layers=params["n_layers"], - rms_norm_eps=params["norm_eps"], - num_key_value_heads=num_key_value_heads, - vocab_size=vocab_size, - rope_theta=base, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=True if openai_version in ["3.2"] else False, - ) + param_num_heads = text_num_key_value_heads + param_num_head_per_shard = text_num_key_value_heads_per_shard + param_dim = text_key_value_dim + shards = [param.view(param_num_head_per_shard, text_dim_per_head, text_dim) for param in current_parameter] + current_parameter = torch.cat(shards, dim=concat_dim) + if "self_attn" not in new_key and "v_proj.weight" not in new_key: + current_parameter = permute_for_rope(current_parameter, param_num_heads, param_dim, text_dim) + state_dict[new_key] = current_parameter.reshape(param_num_heads * text_dim_per_head, text_dim) + + elif "vision_model" in new_key and re.search("(k|v|q)_proj", new_key): + shards = [ + param.view(vision_num_heads_per_shard, vision_dim_per_head, vision_dim) for param in current_parameter + ] + param = torch.cat(shards, dim=concat_dim) + state_dict[new_key] = param.reshape(vision_num_heads * vision_dim_per_head, vision_dim) + + elif new_key == "vision_patch_embedding.weight": + current_parameter = torch.cat(current_parameter, dim=concat_dim) + state_dict[new_key] = current_parameter.reshape( + -1, vision_num_channels, vision_patch_size, vision_patch_size + ) - config.save_pretrained(tmp_model_path) + elif new_key.endswith("gate"): + state_dict[new_key] = current_parameter[0].view(1) + elif "vision_gated_positional_embedding.embedding" in new_key: + current_parameter = interpolate_positional_embedding( + current_parameter, vision_tile_size, vision_patch_size + ) + state_dict[new_key] = current_parameter + + elif "vision_gated_positional_embedding.tile_embedding.weight" in new_key: + current_parameter = current_parameter.permute(2, 0, 1, 3).flatten(1) + current_parameter = interpolate_positional_embedding( + current_parameter, vision_tile_size, vision_patch_size + ) + current_parameter = current_parameter.reshape( + -1, vision_max_num_tiles, vision_max_num_tiles, vision_dim + ).permute(1, 2, 0, 3) + state_dict[new_key] = pre_compute_positional_embedding(current_parameter) + + elif "tile_positional_embedding.embedding" in new_key: + state_dict[new_key] = pre_compute_positional_embedding(current_parameter) + + elif new_key != "": + if isinstance(current_parameter, list): + current_parameter = torch.cat(current_parameter, dim=concat_dim) + state_dict[new_key] = current_parameter + + state_dict["embed_tokens.weight"] = torch.cat( + [ + state_dict["embed_tokens.weight"], + state_dict.pop("learnable_embedding.weight"), + ], + dim=0, + ) + del loaded + gc.collect() + + print("Loading the checkpoint in a Mllama ") + with torch.device("meta"): + model = MllamaForConditionalGeneration(config) + load_state_dict(state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del config._name_or_path + + print("Saving the ") + save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + + # generation config + if instruct: + print("Saving generation config...") generation_config = GenerationConfig( do_sample=True, temperature=0.6, top_p=0.9, bos_token_id=bos_token_id, eos_token_id=eos_token_id, + pad_token_id=pad_token_id, ) - generation_config.save_pretrained(tmp_model_path) - - # Make space so we can load the model properly now. - del state_dict - del loaded - gc.collect() - - print("Loading the checkpoint in a Openai model.") - model = OpenaiForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) - - # Avoid saving this as part of the config. - del model.config._name_or_path - model.config.torch_dtype = torch.float16 - - print("Saving in the Transformers format.") - if push_to_hub: - print("Pushing to the hub.") - model.push_to_hub(model_path, safe_serialization=safe_serialization, private=True, use_temp_dir=True) - else: - print("Saving to disk.") - model.save_pretrained(model_path, safe_serialization=safe_serialization) - - -class Openai3Converter(TikTokenConverter): - def __init__(self, vocab_file, special_tokens=None, instruct=False, openai_version="3.2", **kwargs): - super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs) + generation_config.save_pretrained(model_path) + + +class MllamaConverter(TikTokenConverter): + def __init__( + self, + vocab_file, + special_tokens: List[str], + pattern: str, + model_max_length: int, + chat_template: Optional[str] = None, + **kwargs, + ): + super().__init__(vocab_file, pattern=pattern) + self.additional_special_tokens = special_tokens tokenizer = self.converted() - - # References for chat templates in instruct models - templates_for_version = { - "2": ("meta-openai/Openai-2-7b-chat-hf", "f5db02db724555f92da89c216ac04704f23d4590"), - "3": ("meta-openai/Meta-Openai-3-8B-Instruct", "5f0b02c75b57c5855da9ae460ce51323ea669d8a"), - "3.1": ("meta-openai/Openai-3.1-8B-Instruct", "0e9e39f249a16976918f6564b8830bc894c89659"), - "3.2": ("meta-openai/Openai-3.2-1B-Instruct", "e9f8effbab1cbdc515c11ee6e098e3d5a9f51e14"), - "Guard-3": ("meta-openai/Openai-Guard-3-1B", "acf7aafa60f0410f8f42b1fa35e077d705892029"), - } - - # Add chat_template only if instruct is True. - # Prevents a null chat_template, which triggers - # a parsing warning in the Hub. - additional_kwargs = {} - if instruct or openai_version in ["Guard-3"]: - model_id, revision = templates_for_version.get(openai_version, (None, None)) - if model_id is not None: - from transformers import AutoTokenizer - - t = AutoTokenizer.from_pretrained(model_id, revision=revision) - additional_kwargs["chat_template"] = t.chat_template - - self.converted_tokenizer = PreTrainedTokenizerFast( + if chat_template is not None: + kwargs["chat_template"] = chat_template + self.tokenizer = PreTrainedTokenizerFast( tokenizer_object=tokenizer, - bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", model_input_names=["input_ids", "attention_mask"], - model_max_length=CONTEXT_LENGTH_FOR_VERSION[openai_version], - clean_up_tokenization_spaces=True, - **additional_kwargs, - ) - self.update_post_processor(self.converted_tokenizer) - # finer special_tokens_map.json - self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN - self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN if instruct else EOS_ADDED_TOKEN - - # We can't do this while building the tokenizer because we have no easy access to the bos token id - def update_post_processor(self, tokenizer): - tokenizer._tokenizer.post_processor = processors.Sequence( - [ - processors.ByteLevel(trim_offsets=False), - processors.TemplateProcessing( - single="<|begin_of_text|> $A", - pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1", - special_tokens=[ - ("<|begin_of_text|>", tokenizer.convert_tokens_to_ids("<|begin_of_text|>")), - ], - ), - ] + model_max_length=model_max_length, + **kwargs, ) -def write_tokenizer( - tokenizer_path, input_tokenizer_path, openai_version="2", special_tokens=None, instruct=False, push_to_hub=False -): - print("Converting the tokenizer.") - tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast - if is_openai_3(openai_version): - tokenizer = Openai3Converter( - input_tokenizer_path, - special_tokens, - instruct, - openai_version, - ).converted_tokenizer - else: - try: - tokenizer = tokenizer_class(input_tokenizer_path) - except Exception: - raise ValueError( - "Failed to instantiate tokenizer. Please, make sure you have sentencepiece and protobuf installed." - ) +def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): + model_max_length = CONTEXT_LENGTH + pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605 - if push_to_hub: - print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.") - tokenizer.push_to_hub(tokenizer_path, private=True, use_temp_dir=True) - else: - print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") - tokenizer.save_pretrained(tokenizer_path) - return tokenizer + # Special tokens + num_reserved_special_tokens = 256 + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + special_tokens += [ + f"<|reserved_special_token_{i + 2}|>" for i in range(num_reserved_special_tokens - len(special_tokens)) + ] + # original tokenizer has <|image|> with 128011 token_id, + # however, later in the code it is replaced with 128256 token_id + special_tokens.append("<|image|>") + + # Chat template + chat_template = ( + "{% for message in messages %}" + "{% if loop.index0 == 0 %}" + "{{ bos_token }}" + "{% endif %}" + "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}" + "{% if message['content'] is string %}" + "{{ message['content'] }}" + "{% else %}" + "{% for content in message['content'] %}" + "{% if content['type'] == 'image' %}" + "{{ '<|image|>' }}" + "{% elif content['type'] == 'text' %}" + "{{ content['text'] }}" + "{% endif %}" + "{% endfor %}" + "{% endif %}" + "{{ '<|eot_id|>' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" + "{% endif %}" + ) + + converter = MllamaConverter( + vocab_file=tokenizer_path, + pattern=pattern, + special_tokens=special_tokens, + model_max_length=model_max_length, + chat_template=chat_template if instruct else None, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", + pad_token="<|finetune_right_pad_id|>", + ) + tokenizer = converter.tokenizer + tokenizer.save_pretrained(save_dir) + + if instruct: + print("Saving chat template...") + chat_template_path = os.path.join(save_dir, "chat_template.json") + with open(chat_template_path, "w") as f: + json.dump({"chat_template": chat_template}, f, indent=2) + + +def write_image_processor(config_path: str, save_dir: str): + with open(config_path, "r") as f: + params = json.load(f) + + tile_size = params["vision_chunk_size"] + max_image_tiles = params["vision_max_num_chunks"] + + image_processor = MllamaImageProcessor( + do_resize=True, + size={"height": tile_size, "width": tile_size}, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + do_pad=True, + max_image_tiles=max_image_tiles, + ) + + image_processor.save_pretrained(save_dir) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", - help="Location of Openai weights, which contains tokenizer.model and model folders", - ) - parser.add_argument( - "--model_size", - default=None, - help="'f' Deprecated in favor of `num_shards`: models correspond to the finetuned versions, and are specific to the Openai2 official release. For more details on Openai2, checkout the original repo: https://huggingface.co/meta-openai", + default="Llama-3.2-11B-Vision/original", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", ) parser.add_argument( "--output_dir", + default="Llama-3.2-11B-Vision", help="Location to write HF model and tokenizer", ) parser.add_argument( - "--push_to_hub", - help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", - action="store_true", - default=False, - ) - parser.add_argument( - "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`." + "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`." ) - # Different Openai versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. parser.add_argument( - "--openai_version", - choices=["1", "2", "3", "3.1", "3.2", "Guard-3"], - default="1", - type=str, - help="Version of the Openai model to convert. Currently supports Openai1 and Openai2. Controls the context size", + "--special_tokens", + default=None, + type=List[str], + help="The list of special tokens that should be added to the ", ) parser.add_argument( "--num_shards", - default=None, + default=1, type=int, - help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", - ) - parser.add_argument( - "--special_tokens", - default=None, - type=List[str], - help="The list of special tokens that should be added to the model.", + help="The number of individual shards used for the Does not have to be the same as the number of consolidated_xx.pth", ) parser.add_argument( "--instruct", action="store_true", - default=False, - help="Whether the model is an instruct model or not. Will affect special tokens and chat template.", + help="Whether the model is an instruct model", ) args = parser.parse_args() - if args.model_size is None and args.num_shards is None: - raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`") - if args.special_tokens is None: - # no special tokens by default - args.special_tokens = DEFAULT_OPENAI_SPECIAL_TOKENS.get(str(args.openai_version), []) - - spm_path = os.path.join(args.input_dir, "tokenizer.model") - vocab_size = len( - write_tokenizer( - args.output_dir, - spm_path, - openai_version=args.openai_version, - special_tokens=args.special_tokens, - instruct=args.instruct, - push_to_hub=args.push_to_hub, - ) + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + num_shards=args.num_shards, + instruct=args.instruct, ) - if args.model_size != "tokenizer_only": - write_model( - model_path=args.output_dir, - input_base_path=args.input_dir, - model_size=args.model_size, - safe_serialization=args.safe_serialization, - openai_version=args.openai_version, - vocab_size=vocab_size, - num_shards=args.num_shards, - instruct=args.instruct, - push_to_hub=args.push_to_hub, - ) + write_tokenizer( + tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), + save_dir=args.output_dir, + instruct=args.instruct, + ) + + write_image_processor( + config_path=os.path.join(args.input_dir, "params.json"), + save_dir=args.output_dir, + ) if __name__ == "__main__": From 16ed546b402328358968af4dc2c0020dc7e8801c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:42:03 +0000 Subject: [PATCH 005/342] clean conversion script --- .../openai/convert_openai_weights_to_hf.py | 241 ++---------------- .../models/openai/modeling_openai.py | 5 - 2 files changed, 25 insertions(+), 221 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 18955cc6e1f0..3405bf3e96dc 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -16,19 +16,19 @@ import argparse import gc import json -import math + import os from typing import List, Optional import regex as re import torch import torch.nn.functional as F - +import tqdm +from safetensors.torch import load_file as safe_load from transformers import ( GenerationConfig, - MllamaConfig, - MllamaForConditionalGeneration, - MllamaImageProcessor, + OpenaiConfig, + OpenaiForCausalLM, PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import TikTokenConverter @@ -41,7 +41,7 @@ r"embedding": r"embed_tokens", r"rope.freqs": None, # meaning we skip it and don't want it # special key, wqkv needs to be split afterwards - r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(q|k|v)_proj", + r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.\2_proj", r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", @@ -91,145 +91,18 @@ def write_model( instruct=False, ): os.makedirs(model_path, exist_ok=True) - - with open(os.path.join(input_base_path, "params.json"), "r") as f: - params = json.load(f) - - params = params.get("model", params) torch_dtype = "bfloat16" - # ------------------------------------------------------------ - # Text model params and config - # ------------------------------------------------------------ - - # params from config - text_vocab_size = params["vocab_size"] - text_num_layers = params["n_layers"] - text_dim = params["dim"] - text_num_heads = params["n_heads"] - text_rms_norm_eps = params["norm_eps"] - text_rope_theta = params["rope_theta"] - cross_attention_num_layers = params["vision_num_block"] - - # some constants from original code - rope_scaling = { - "rope_type": "llama3", - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - } - max_position_embeddings = CONTEXT_LENGTH - - # compute additional params for weight conversion - text_num_heads_per_shard = text_num_heads // num_shards - text_dim_per_head = text_dim // text_num_heads - text_intermediate_size = compute_intermediate_size(text_dim, multiple_of=params["multiple_of"]) - - if params.get("n_kv_heads", None) is not None: - text_num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - text_num_key_value_heads_per_shard = text_num_key_value_heads // num_shards - text_key_value_dim = text_dim_per_head * text_num_key_value_heads - else: # compatibility with other checkpoints - text_num_key_value_heads = text_num_heads - text_num_key_value_heads_per_shard = text_num_heads_per_shard - text_key_value_dim = text_dim - - # cross-attention layers: 20 for 90B, 8 for 11B - cross_attention_frequency = math.ceil(text_num_layers / cross_attention_num_layers) - text_num_total_layers = text_num_layers + cross_attention_num_layers - block_shift = list( - range(cross_attention_frequency - 1, text_num_total_layers, cross_attention_frequency + 1) - ) - self_attention_layers_shift = [k for k in range(text_num_total_layers) if k not in block_shift] - bos_token_id = 128000 eos_token_id = [128001, 128008, 128009] if instruct else 128001 pad_token_id = 128004 - text_config = MllamaTextConfig( - num_attention_heads=text_num_heads, - vocab_size=text_vocab_size, - hidden_size=text_dim, - rms_norm_eps=text_rms_norm_eps, - rope_theta=text_rope_theta, - num_hidden_layers=text_num_total_layers, - block=block_shift, - intermediate_size=text_intermediate_size, - max_position_embeddings=max_position_embeddings, - rope_scaling=rope_scaling, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - tie_word_embeddings=False, # Constant set to False - torch_dtype=torch_dtype, - ) - - # ------------------------------------------------------------ - # Vision model params and config - # ------------------------------------------------------------ - - # params from config - vision_tile_size = params["vision_chunk_size"] - vision_max_num_tiles = params["vision_max_num_chunks"] - - # some constants from original code - vision_patch_size = 14 - vision_num_channels = 3 - vision_num_layers = 32 - vision_num_layers_global = 8 - vision_dim = 1280 - vision_num_heads = 16 - vision_intermediate_layers_indices = [3, 7, 15, 23, 30] - - # compute additional params for weight conversion - vision_dim_per_head = vision_dim // vision_num_heads - vision_num_heads_per_shard = vision_num_heads // num_shards - vision_intermediate_size = vision_dim * 4 - vision_supported_aspect_ratios = get_all_supported_aspect_ratios(vision_max_num_tiles) - - vision_config = MllamaVisionConfig( - hidden_size=vision_dim, - patch_size=vision_patch_size, - num_channels=vision_num_channels, - intermediate_size=vision_intermediate_size, - num_hidden_layers=vision_num_layers, - num_attention_heads=vision_num_heads, - num_global_layers=vision_num_layers_global, - intermediate_layers_indices=vision_intermediate_layers_indices, - image_size=vision_tile_size, - max_num_tiles=vision_max_num_tiles, - supported_aspect_ratios=vision_supported_aspect_ratios, - torch_dtype=torch_dtype, - ) - - # save config - config = MllamaConfig(vision_config=vision_config, text_config=text_config, torch_dtype=torch_dtype) - config.architectures = ["MllamaForConditionalGeneration"] - config.save_pretrained(model_path) - print("Model config saved successfully...") - - # ------------------------------------------------------------ - # Convert weights - # ------------------------------------------------------------ + config = OpenaiConfig.from_pretrained(input_base_path) + print(f"Fetching all parameters from the checkpoint at {input_base_path}...") - if num_shards == 1: - if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")): - path = os.path.join(input_base_path, "consolidated.00.pth") - else: - path = os.path.join(input_base_path, "consolidated.pth") - loaded = [torch.load(path, map_location="cpu", mmap=True, weights_only=True)] - else: - loaded = [ - torch.load( - os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), - map_location="cpu", - mmap=True, - weights_only=True, - ) - for i in range(num_shards) - ] + + loaded = [safe_load(file for file in tqdm(os.listdir(model_path), desc="Loading shards", unit="shard") if file.endswith(".safetensors"))] print("Converting ..") all_keys = list(loaded[0].keys()) @@ -237,101 +110,37 @@ def write_model( state_dict = {} for key in all_keys: - new_key = new_keys[key] - - # In the original model, self-attention layers and cross-attention layers are different lists of layers. - # In the converted model, they are merged into one list with corresponding index shift to preserve the order. - if ("cross_attention" in key or "layers" in key) and "language_model" in new_key: - shift = block_shift if "cross_attention" in key else self_attention_layers_shift - new_key = re.sub(r"layers.(\d+).", lambda _match: f"layers.{shift[int(_match.groups()[0])]}.", new_key) - - current_parameter = [chunk.pop(key).contiguous().clone() for chunk in loaded] - if not is_param_different_across_shards(new_key): - current_parameter = current_parameter[0] - - concat_dim = get_concat_dim(new_key) - # Post-process the current_parameter. + new_key = new_keys.get(key, key) if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key: - if "q_proj" in new_key: - param_num_heads = text_num_heads - param_num_head_per_shard = text_num_heads_per_shard - param_dim = text_dim - else: - param_num_heads = text_num_key_value_heads - param_num_head_per_shard = text_num_key_value_heads_per_shard - param_dim = text_key_value_dim - shards = [param.view(param_num_head_per_shard, text_dim_per_head, text_dim) for param in current_parameter] - current_parameter = torch.cat(shards, dim=concat_dim) - if "self_attn" not in new_key and "v_proj.weight" not in new_key: - current_parameter = permute_for_rope(current_parameter, param_num_heads, param_dim, text_dim) - state_dict[new_key] = current_parameter.reshape(param_num_heads * text_dim_per_head, text_dim) - - elif "vision_model" in new_key and re.search("(k|v|q)_proj", new_key): - shards = [ - param.view(vision_num_heads_per_shard, vision_dim_per_head, vision_dim) for param in current_parameter - ] - param = torch.cat(shards, dim=concat_dim) - state_dict[new_key] = param.reshape(vision_num_heads * vision_dim_per_head, vision_dim) - - elif new_key == "vision_patch_embedding.weight": - current_parameter = torch.cat(current_parameter, dim=concat_dim) - state_dict[new_key] = current_parameter.reshape( - -1, vision_num_channels, vision_patch_size, vision_patch_size - ) - - elif new_key.endswith("gate"): - state_dict[new_key] = current_parameter[0].view(1) - - elif "vision_gated_positional_embedding.embedding" in new_key: - current_parameter = interpolate_positional_embedding( - current_parameter, vision_tile_size, vision_patch_size - ) - state_dict[new_key] = current_parameter - - elif "vision_gated_positional_embedding.tile_embedding.weight" in new_key: - current_parameter = current_parameter.permute(2, 0, 1, 3).flatten(1) - current_parameter = interpolate_positional_embedding( - current_parameter, vision_tile_size, vision_patch_size - ) - current_parameter = current_parameter.reshape( - -1, vision_max_num_tiles, vision_max_num_tiles, vision_dim - ).permute(1, 2, 0, 3) - state_dict[new_key] = pre_compute_positional_embedding(current_parameter) - - elif "tile_positional_embedding.embedding" in new_key: - state_dict[new_key] = pre_compute_positional_embedding(current_parameter) - - elif new_key != "": - if isinstance(current_parameter, list): - current_parameter = torch.cat(current_parameter, dim=concat_dim) - state_dict[new_key] = current_parameter - - state_dict["embed_tokens.weight"] = torch.cat( - [ - state_dict["embed_tokens.weight"], - state_dict.pop("learnable_embedding.weight"), - ], - dim=0, - ) + q, k , v = loaded[0][key].chunk(3, dim=-1) + q_key = re.sub(r"(k|v|q)_proj.weight", "q_proj.weight", new_key) + state_dict[q_key] = q + k_key = re.sub(r"(k|v|q)_proj.weight", "k_proj.weight", new_key) + v_key = re.sub(r"(k|v|q)_proj.weight", "v_proj.weight", new_key) + state_dict[k_key] = k + state_dict[v_key] = v + else: + state_dict[new_key] = loaded[0][key] + del loaded gc.collect() print("Loading the checkpoint in a Mllama ") with torch.device("meta"): - model = MllamaForConditionalGeneration(config) - load_state_dict(state_dict, strict=True, assign=True) + model = OpenaiForCausalLM(config) + model.load_state_dict(state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") del config._name_or_path print("Saving the ") - save_pretrained(model_path, safe_serialization=safe_serialization) + model.save_pretrained(model_path, safe_serialization=safe_serialization) del state_dict, model # Safety check: reload the converted model gc.collect() print("Reloading the model to check if it's saved correctly.") - MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + OpenaiForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") # generation config diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 53b986b00a95..d64fa4ae583f 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -163,11 +163,6 @@ class OpenaiPreTrainedModel(LlamaPreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OpenaiDecoderLayer"] - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True class OpenaiModel(LlamaModel): From 77a5529d0c43dc18590117ab9c52abdb29d8ca23 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:44:24 +0000 Subject: [PATCH 006/342] sytling and updates --- .../integrations/flex_attention.py | 1 + .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 10 +-- .../models/auto/tokenization_auto.py | 14 +-- .../models/openai/configuration_openai.py | 3 +- .../openai/convert_openai_weights_to_hf.py | 85 ++++--------------- .../models/openai/modeling_openai.py | 47 +++++----- tests/models/openai/test_modeling_openai.py | 4 +- 8 files changed, 61 insertions(+), 107 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 98dbef5529ef..fe20791a9056 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -214,6 +214,7 @@ def flex_attention_forward( scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, + score_mod: Optional[callable] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: block_mask = None diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3e812e3962fb..aadf21a90a83 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -182,7 +182,6 @@ ("levit", "LevitConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), - ("openai", "OpenaiConfig"), ("llama4", "Llama4Config"), ("llama4_text", "Llama4TextConfig"), ("llava", "LlavaConfig"), @@ -239,6 +238,7 @@ ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), + ("openai", "OpenaiConfig"), ("openai-gpt", "OpenAIGPTConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), @@ -551,7 +551,6 @@ ("levit", "LeViT"), ("lilt", "LiLT"), ("llama", "LLaMA"), - ("openai", "openai"), ("llama2", "Llama2"), ("llama3", "Llama3"), ("llama4", "Llama4"), @@ -618,6 +617,7 @@ ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), + ("openai", "openai"), ("openai-gpt", "OpenAI GPT"), ("opt", "OPT"), ("owlv2", "OWLv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index cc9395c521a6..3881dcffa88b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -177,7 +177,6 @@ ("levit", "LevitModel"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), - ("openai", "OpenaiModel"), ("llama4", "Llama4ForConditionalGeneration"), ("llama4_text", "Llama4TextModel"), ("llava", "LlavaModel"), @@ -233,6 +232,7 @@ ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), + ("openai", "OpenaiModel"), ("openai-gpt", "OpenAIGPTModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), @@ -586,7 +586,6 @@ ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), - ("openai", "OpenaiForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), ("mamba", "MambaForCausalLM"), @@ -608,6 +607,7 @@ ("olmo2", "Olmo2ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), + ("openai", "OpenaiForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), @@ -1101,7 +1101,6 @@ ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), ("llama", "LlamaForSequenceClassification"), - ("openai", "OpenaiForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), @@ -1121,6 +1120,7 @@ ("nezha", "NezhaForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("open-llama", "OpenLlamaForSequenceClassification"), + ("openai", "OpenaiForSequenceClassification"), ("openai-gpt", "OpenAIGPTForSequenceClassification"), ("opt", "OPTForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), @@ -1192,7 +1192,6 @@ ("led", "LEDForQuestionAnswering"), ("lilt", "LiltForQuestionAnswering"), ("llama", "LlamaForQuestionAnswering"), - ("openai", "OpenaiForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ("luke", "LukeForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), @@ -1212,6 +1211,7 @@ ("nemotron", "NemotronForQuestionAnswering"), ("nezha", "NezhaForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), + ("openai", "OpenaiForQuestionAnswering"), ("opt", "OPTForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), ("qwen2", "Qwen2ForQuestionAnswering"), @@ -1301,7 +1301,6 @@ ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), ("llama", "LlamaForTokenClassification"), - ("openai", "OpenaiForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", "LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), @@ -1318,6 +1317,7 @@ ("nemotron", "NemotronForTokenClassification"), ("nezha", "NezhaForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"), + ("openai", "OpenaiForTokenClassification"), ("persimmon", "PersimmonForTokenClassification"), ("phi", "PhiForTokenClassification"), ("phi3", "Phi3ForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9dcd70c27e85..c5fe3358df4d 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -296,13 +296,6 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "openai", - ( - None, - "PreTrainedTokenizerFast" if is_tokenizers_available() else None, - ), - ), ( "llama4", ( @@ -417,6 +410,13 @@ ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None), ), ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ( + "openai", + ( + None, + "PreTrainedTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None), diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index 5fcb6c0c9cae..b8d2c589d001 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -24,7 +24,6 @@ class OpenaiConfig(PretrainedConfig): - model_type = "openai" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `OpenaiModel` @@ -55,7 +54,7 @@ def __init__( num_key_value_heads: int = 8, sliding_window: int = 128, rope_theta: float = 150000.0, - tie_word_embeddings= False, + tie_word_embeddings=False, hidden_act: str = "silu", initializer_range: float = 0.02, rms_norm_eps: float = 1e-6, diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 3405bf3e96dc..118d5f291fbf 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -1,4 +1,3 @@ - # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,15 +15,14 @@ import argparse import gc import json - import os from typing import List, Optional import regex as re import torch -import torch.nn.functional as F import tqdm from safetensors.torch import load_file as safe_load + from transformers import ( GenerationConfig, OpenaiConfig, @@ -32,6 +30,8 @@ PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import TikTokenConverter + + # fmt: off # If a weight needs to be split in two or more keys, use `|` to indicate it. ex: # r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" @@ -91,18 +91,21 @@ def write_model( instruct=False, ): os.makedirs(model_path, exist_ok=True) - torch_dtype = "bfloat16" - bos_token_id = 128000 eos_token_id = [128001, 128008, 128009] if instruct else 128001 pad_token_id = 128004 config = OpenaiConfig.from_pretrained(input_base_path) - print(f"Fetching all parameters from the checkpoint at {input_base_path}...") - loaded = [safe_load(file for file in tqdm(os.listdir(model_path), desc="Loading shards", unit="shard") if file.endswith(".safetensors"))] + loaded = [ + safe_load( + file + for file in tqdm(os.listdir(model_path), desc="Loading shards", unit="shard") + if file.endswith(".safetensors") + ) + ] print("Converting ..") all_keys = list(loaded[0].keys()) @@ -113,7 +116,7 @@ def write_model( # Post-process the current_parameter. new_key = new_keys.get(key, key) if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key: - q, k , v = loaded[0][key].chunk(3, dim=-1) + q, k, v = loaded[0][key].chunk(3, dim=-1) q_key = re.sub(r"(k|v|q)_proj.weight", "q_proj.weight", new_key) state_dict[q_key] = q k_key = re.sub(r"(k|v|q)_proj.weight", "k_proj.weight", new_key) @@ -157,7 +160,7 @@ def write_model( generation_config.save_pretrained(model_path) -class MllamaConverter(TikTokenConverter): +class OpenaiConverter(TikTokenConverter): def __init__( self, vocab_file, @@ -168,6 +171,7 @@ def __init__( **kwargs, ): super().__init__(vocab_file, pattern=pattern) + # TODO 1st donwload the vocabfile!!! self.additional_special_tokens = special_tokens tokenizer = self.converted() if chat_template is not None: @@ -183,28 +187,7 @@ def __init__( def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): model_max_length = CONTEXT_LENGTH pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605 - - # Special tokens - num_reserved_special_tokens = 256 - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - special_tokens += [ - f"<|reserved_special_token_{i + 2}|>" for i in range(num_reserved_special_tokens - len(special_tokens)) - ] - # original tokenizer has <|image|> with 128011 token_id, - # however, later in the code it is replaced with 128256 token_id - special_tokens.append("<|image|>") + special_tokens = [] # Chat template chat_template = ( @@ -231,7 +214,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): "{% endif %}" ) - converter = MllamaConverter( + converter = OpenaiConverter( vocab_file=tokenizer_path, pattern=pattern, special_tokens=special_tokens, @@ -251,38 +234,16 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): json.dump({"chat_template": chat_template}, f, indent=2) -def write_image_processor(config_path: str, save_dir: str): - with open(config_path, "r") as f: - params = json.load(f) - - tile_size = params["vision_chunk_size"] - max_image_tiles = params["vision_max_num_chunks"] - - image_processor = MllamaImageProcessor( - do_resize=True, - size={"height": tile_size, "width": tile_size}, - do_rescale=True, - rescale_factor=1 / 255, - do_normalize=True, - image_mean=[0.48145466, 0.4578275, 0.40821073], - image_std=[0.26862954, 0.26130258, 0.27577711], - do_pad=True, - max_image_tiles=max_image_tiles, - ) - - image_processor.save_pretrained(save_dir) - - def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", - default="Llama-3.2-11B-Vision/original", + default="/fsx/arthur/oai", help="Location of LLaMA weights, which contains tokenizer.model and model folders", ) parser.add_argument( "--output_dir", - default="Llama-3.2-11B-Vision", + default="/fsx/arthur/oai_hf", help="Location to write HF model and tokenizer", ) parser.add_argument( @@ -294,12 +255,7 @@ def main(): type=List[str], help="The list of special tokens that should be added to the ", ) - parser.add_argument( - "--num_shards", - default=1, - type=int, - help="The number of individual shards used for the Does not have to be the same as the number of consolidated_xx.pth", - ) + parser.add_argument( "--instruct", action="store_true", @@ -320,11 +276,6 @@ def main(): instruct=args.instruct, ) - write_image_processor( - config_path=os.path.join(args.input_dir, "params.json"), - save_dir=args.output_dir, - ) - if __name__ == "__main__": main() diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index d64fa4ae583f..3aa6a1b1165e 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -17,40 +17,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple, Union +from typing import Optional import torch import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs - -from ...modeling_outputs import ( - BaseModelOutputWithPast, +from ...integrations.flex_attention import flex_attention_forward +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + repeat_kv, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from ..llama4.modeling_llama4 import Llama4TextExperts from .configuration_openai import OpenaiConfig -from ..llama4.modeling_llama4 import apply_rotary_pos_emb, Llama4TextExperts -from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM, LlamaRMSNorm, repeat_kv, LlamaPreTrainedModel -from ...integrations.flex_attention import flex_attention_forward + + logger = logging.get_logger(__name__) class OpenaiRMSNorm(LlamaRMSNorm): pass + class OpenaiExperts(Llama4TextExperts): pass + class OpenaiMLP(nn.Module): def __init__(self, config): super().__init__() @@ -75,11 +75,11 @@ def forward(self, hidden_states): out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0)) return out, router_scores + class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): pass - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -117,6 +117,7 @@ def openai_flex_attention_forward( **kwargs, ): sink = module.sink + def attention_sink(score, b, h, q_idx, kv_idx): score = torch.cat([score, sink], dim=-1) return score @@ -134,9 +135,8 @@ def attention_sink(score, b, h, q_idx, kv_idx): **kwargs, ) -ALL_ATTENTION_FUNCTIONS.register( - "openai_flex_attention", openai_flex_attention_forward -) + +ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) class OpenaiAttention(LlamaAttention): @@ -166,7 +166,8 @@ class OpenaiPreTrainedModel(LlamaPreTrainedModel): class OpenaiModel(LlamaModel): - pass + pass + class OpenaiForCausalLM(LlamaForCausalLM): pass diff --git a/tests/models/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py index 68516a691500..4239e0816bf7 100644 --- a/tests/models/openai/test_modeling_openai.py +++ b/tests/models/openai/test_modeling_openai.py @@ -600,7 +600,9 @@ def test_compile_static_cache(self): "Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup.", ] - tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right") + tokenizer = LlamaTokenizer.from_pretrained( + "meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right" + ) model = OpenaiForCausalLM.from_pretrained( "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 ) From 5fa5670dba972c2786d8f89d9ec4a681b9872849 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:45:29 +0000 Subject: [PATCH 007/342] yes --- .../openai/convert_openai_weights_to_hf.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 118d5f291fbf..7c6c91b939c0 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -36,22 +36,22 @@ # If a weight needs to be split in two or more keys, use `|` to indicate it. ex: # r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" ORIGINAL_TO_CONVERTED_KEY_MAPPING = { - r"norm.weight": r"norm.weight", - r"unembedding.weight": r"lm_head.weight", - r"embedding": r"embed_tokens", - r"rope.freqs": None, # meaning we skip it and don't want it + r"norm.weight": r"norm.weight", + r"unembedding.weight": r"lm_head.weight", + r"embedding": r"embed_tokens", + r"rope.freqs": None, # meaning we skip it and don't want it # special key, wqkv needs to be split afterwards - r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", - r"block.(\d+).attn.out": r"layers.\1.self_attn.\2_proj", - r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", - r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", - - r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.gate_up_proj.weight", - r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.gate_up_proj.bias", - r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.down_proj.weight", - r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.down_proj.bias", - r"block.(\d+).mlp.norm": r"layers.\1.post_attention_layernorm.weight", - r"block.(\d+).mlp.gate": r"layers.\1.mlp.router.weight", + r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", + r"block.(\d+).attn.out": r"layers.\1.self_attn.\2_proj", + r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", + r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", + + r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.gate_up_proj.weight", + r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.gate_up_proj.bias", + r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.down_proj.weight", + r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.down_proj.bias", + r"block.(\d+).mlp.norm": r"layers.\1.post_attention_layernorm.weight", + r"block.(\d+).mlp.gate": r"layers.\1.mlp.router.weight", } # fmt: on @@ -76,12 +76,6 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): return output_dict -def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3): - hidden_dim = 4 * int(2 * hidden_dim / 3) - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - return hidden_dim - def write_model( model_path, From 10257725c44b016b5930e0536a29ad84fddb3f49 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:45:46 +0000 Subject: [PATCH 008/342] nits --- src/transformers/models/openai/convert_openai_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 7c6c91b939c0..0953933d427e 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -42,7 +42,7 @@ r"rope.freqs": None, # meaning we skip it and don't want it # special key, wqkv needs to be split afterwards r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", - r"block.(\d+).attn.out": r"layers.\1.self_attn.\2_proj", + r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", From 2d8ed7faef3774803f93cc0f474a798da90dd7df Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:46:21 +0000 Subject: [PATCH 009/342] up --- src/transformers/models/openai/convert_openai_weights_to_hf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 0953933d427e..6aa499ae1d04 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -39,7 +39,6 @@ r"norm.weight": r"norm.weight", r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", - r"rope.freqs": None, # meaning we skip it and don't want it # special key, wqkv needs to be split afterwards r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", @@ -55,7 +54,6 @@ } # fmt: on -CONTEXT_LENGTH = 131072 def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): From 21bdc8d7a6f93b6267dc39487ef381275e855a01 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:48:26 +0000 Subject: [PATCH 010/342] remove shard arg --- src/transformers/models/openai/convert_openai_weights_to_hf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 6aa499ae1d04..1d9ea8bfbe1e 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -40,7 +40,7 @@ r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", # special key, wqkv needs to be split afterwards - r"block.(\d+).attn.qkv": r"layers.\1.self_attn.(k|v|q)_proj", + r"block.(\d+).attn.qkv": r"layers.\1.self_attn.k|v|q_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", @@ -258,7 +258,6 @@ def main(): model_path=args.output_dir, input_base_path=args.input_dir, safe_serialization=args.safe_serialization, - num_shards=args.num_shards, instruct=args.instruct, ) From d3b519b4d9ad4ae1e5857e36d59e50bc7a3a105f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:51:39 +0000 Subject: [PATCH 011/342] remove annoyance --- .../models/openai/convert_openai_weights_to_hf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 1d9ea8bfbe1e..feace025e2c5 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -20,7 +20,7 @@ import regex as re import torch -import tqdm +from tqdm import tqdm from safetensors.torch import load_file as safe_load from transformers import ( @@ -78,7 +78,6 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): def write_model( model_path, input_base_path, - num_shards, safe_serialization=True, instruct=False, ): @@ -90,11 +89,11 @@ def write_model( config = OpenaiConfig.from_pretrained(input_base_path) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") - + print(list(os.listdir(input_base_path))) loaded = [ safe_load( file - for file in tqdm(os.listdir(model_path), desc="Loading shards", unit="shard") + for file in list(os.listdir(input_base_path)) if file.endswith(".safetensors") ) ] From d9b30f697a33f45533229ab4080b04e37dbef254 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:52:21 +0000 Subject: [PATCH 012/342] ckpt files path join --- .../models/openai/convert_openai_weights_to_hf.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index feace025e2c5..7308c51d1fe4 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -89,13 +89,10 @@ def write_model( config = OpenaiConfig.from_pretrained(input_base_path) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") - print(list(os.listdir(input_base_path))) loaded = [ - safe_load( - file - for file in list(os.listdir(input_base_path)) + safe_load(os.path.join(input_base_path,file)) for file in list(os.listdir(input_base_path)) if file.endswith(".safetensors") - ) + ] print("Converting ..") From 178f1270301d73896ea3ef02dab4363ad4e63350 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:55:03 +0000 Subject: [PATCH 013/342] yup forgot to use the llama one --- src/transformers/models/openai/configuration_openai.py | 4 ++++ src/transformers/models/openai/modeling_openai.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index b8d2c589d001..b9a12c1d5b58 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -91,6 +91,10 @@ def __init__( self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) + self.attention_bias = False + self.mlp_bias = False + self.max_position_embeddings = 8192 + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 3aa6a1b1165e..5b293a648e44 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -166,7 +166,11 @@ class OpenaiPreTrainedModel(LlamaPreTrainedModel): class OpenaiModel(LlamaModel): - pass + def __init__(self, config: OpenaiConfig): + super().__init__(config) + self.rope = OpenaiRotaryEmbedding(config) + self.layers = nn.ModuleList([OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class OpenaiForCausalLM(LlamaForCausalLM): From ff452aeea0439e795c70b7c58dde008cf13eb351 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 18:59:48 +0000 Subject: [PATCH 014/342] small nits --- .../models/openai/configuration_openai.py | 3 ++- .../openai/convert_openai_weights_to_hf.py | 4 ++-- .../models/openai/modeling_openai.py | 18 ++++++++---------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index b9a12c1d5b58..82e171512c37 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -63,6 +63,7 @@ def __init__( eos_token_id: int = 2, rope_scaling=None, attention_dropout: float = 0.0, + num_experts_per_tok=4, **kwargs, ): self.vocab_size = vocab_size @@ -72,7 +73,7 @@ def __init__( self.num_attention_heads = num_attention_heads self.num_experts = num_experts self.sliding_window = sliding_window - + self.num_experts_per_tok = num_experts_per_tok # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 7308c51d1fe4..d1886590a2f2 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -102,7 +102,7 @@ def write_model( state_dict = {} for key in all_keys: # Post-process the current_parameter. - new_key = new_keys.get(key, key) + new_key = "model." + new_keys.get(key, key) if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key: q, k, v = loaded[0][key].chunk(3, dim=-1) q_key = re.sub(r"(k|v|q)_proj.weight", "q_proj.weight", new_key) @@ -117,7 +117,7 @@ def write_model( del loaded gc.collect() - print("Loading the checkpoint in a Mllama ") + print("Loading the checkpoint in a OpenAI ") with torch.device("meta"): model = OpenaiForCausalLM(config) model.load_state_dict(state_dict, strict=True, assign=True) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 5b293a648e44..fa971d4b221e 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -56,7 +56,7 @@ def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts + self.num_local_experts = config.num_local_experts self.experts = OpenaiExperts(config) self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) @@ -68,11 +68,11 @@ def forward(self, hidden_states): torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - routed_in = hidden_states.repeat(self.num_experts, 1) + routed_in = hidden_states.repeat(self.num_local_experts, 1) routed_in = routed_in * router_scores.reshape(-1, 1) routed_out = self.experts(routed_in) out = self.shared_expert(hidden_states) - out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0)) + out.add_(routed_out.reshape(self.num_local_experts, -1, self.hidden_dim).sum(dim=0)) return out, router_scores @@ -140,17 +140,13 @@ def attention_sink(score, b, h, q_idx, kv_idx): class OpenaiAttention(LlamaAttention): - """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: OpenaiConfig, layer_idx: int): - super().__init__() + super().__init__(config, layer_idx) self.sinks = torch.empty(config.num_attention_heads) - -# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Openai class OpenaiDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): - super().__init__() + super().__init__(config, layer_idx) self.hidden_size = config.hidden_size self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) self.mlp = OpenaiMLP(config) @@ -174,7 +170,9 @@ def __init__(self, config: OpenaiConfig): class OpenaiForCausalLM(LlamaForCausalLM): - pass + def __init__(self, config: OpenaiConfig): + super().__init__(config) + self.model = OpenaiModel(config) __all__ = [ From 08afd6bf136ce395cb74da95ab8689018d290775 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:01:41 +0000 Subject: [PATCH 015/342] fix --- src/transformers/models/openai/configuration_openai.py | 4 ++-- .../models/openai/convert_openai_weights_to_hf.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index 82e171512c37..bef12f373048 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -45,7 +45,7 @@ class OpenaiConfig(PretrainedConfig): def __init__( self, num_hidden_layers: int = 36, - num_experts: int = 128, + num_local_experts: int = 128, vocab_size: int = 201088, hidden_size: int = 2880, intermediate_size: int = 2880, @@ -71,7 +71,7 @@ def __init__( self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.num_experts = num_experts + self.num_local_experts = num_local_experts self.sliding_window = sliding_window self.num_experts_per_tok = num_experts_per_tok # for backward compatibility diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index d1886590a2f2..54b5415ae2fb 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -40,7 +40,7 @@ r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", # special key, wqkv needs to be split afterwards - r"block.(\d+).attn.qkv": r"layers.\1.self_attn.k|v|q_proj", + r"block.(\d+).attn.qkv": r"layers.\1.self_attn.qkv_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", @@ -103,12 +103,12 @@ def write_model( for key in all_keys: # Post-process the current_parameter. new_key = "model." + new_keys.get(key, key) - if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key: + if re.search("qkv_proj.weight", new_key): q, k, v = loaded[0][key].chunk(3, dim=-1) - q_key = re.sub(r"(k|v|q)_proj.weight", "q_proj.weight", new_key) + q_key = re.sub(r"(qkv)_proj.weight", "q", new_key) state_dict[q_key] = q - k_key = re.sub(r"(k|v|q)_proj.weight", "k_proj.weight", new_key) - v_key = re.sub(r"(k|v|q)_proj.weight", "v_proj.weight", new_key) + k_key = re.sub(r"(qkv)_proj.weight", "k", new_key) + v_key = re.sub(r"(qkv)_proj.weight", "v", new_key) state_dict[k_key] = k state_dict[v_key] = v else: From f0004e040061411568a5830bbd45c3e999f56935 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:04:24 +0000 Subject: [PATCH 016/342] something was too legacy --- .../models/openai/convert_openai_weights_to_hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 54b5415ae2fb..d9938482391a 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -96,15 +96,17 @@ def write_model( ] print("Converting ..") - all_keys = list(loaded[0].keys()) + weights = torch.cat(loaded, dim=0) + all_keys = weights.keys() new_keys = convert_old_keys_to_new_keys(all_keys) state_dict = {} for key in all_keys: # Post-process the current_parameter. new_key = "model." + new_keys.get(key, key) + print(f"Processing key: {key} -> {new_key}") if re.search("qkv_proj.weight", new_key): - q, k, v = loaded[0][key].chunk(3, dim=-1) + q, k, v = weights[key].chunk(3, dim=-1) q_key = re.sub(r"(qkv)_proj.weight", "q", new_key) state_dict[q_key] = q k_key = re.sub(r"(qkv)_proj.weight", "k", new_key) From f522b89fc7efee5bee31a38a97345f58bb717c1e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:06:07 +0000 Subject: [PATCH 017/342] del as well --- .../openai/convert_openai_weights_to_hf.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index d9938482391a..a2d7b45bd7f1 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -89,15 +89,13 @@ def write_model( config = OpenaiConfig.from_pretrained(input_base_path) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") - loaded = [ - safe_load(os.path.join(input_base_path,file)) for file in list(os.listdir(input_base_path)) - if file.endswith(".safetensors") - - ] + final_ = {} + for file in list(os.listdir(input_base_path)): + if file.endswith(".safetensors"): + final_.update(safe_load(os.path.join(input_base_path,file)) ) print("Converting ..") - weights = torch.cat(loaded, dim=0) - all_keys = weights.keys() + all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) state_dict = {} @@ -106,7 +104,7 @@ def write_model( new_key = "model." + new_keys.get(key, key) print(f"Processing key: {key} -> {new_key}") if re.search("qkv_proj.weight", new_key): - q, k, v = weights[key].chunk(3, dim=-1) + q, k, v = final_[key].chunk(3, dim=-1) q_key = re.sub(r"(qkv)_proj.weight", "q", new_key) state_dict[q_key] = q k_key = re.sub(r"(qkv)_proj.weight", "k", new_key) @@ -114,9 +112,9 @@ def write_model( state_dict[k_key] = k state_dict[v_key] = v else: - state_dict[new_key] = loaded[0][key] + state_dict[new_key] = final_[key] - del loaded + del final_ gc.collect() print("Loading the checkpoint in a OpenAI ") From 58eced8ac16fc9a7fdd598fd48aa584077f89f53 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:18:45 +0000 Subject: [PATCH 018/342] all keys match, onto the shapes! --- .../models/openai/configuration_openai.py | 2 +- .../openai/convert_openai_weights_to_hf.py | 32 ++++++++------- .../models/openai/modeling_openai.py | 41 ++++++++++++++++--- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index bef12f373048..d3637633ffbc 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -92,7 +92,7 @@ def __init__( self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - self.attention_bias = False + self.attention_bias = True self.mlp_bias = False self.max_position_embeddings = 8192 diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index a2d7b45bd7f1..673ba966eae1 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -36,21 +36,21 @@ # If a weight needs to be split in two or more keys, use `|` to indicate it. ex: # r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" ORIGINAL_TO_CONVERTED_KEY_MAPPING = { - r"norm.weight": r"norm.weight", + r"norm.scale": r"norm.weight", r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", # special key, wqkv needs to be split afterwards r"block.(\d+).attn.qkv": r"layers.\1.self_attn.qkv_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", - r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", - r"block.(\d+).attn.norm": r"layers.\1.input_layernorm.weight", - - r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.gate_up_proj.weight", - r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.gate_up_proj.bias", - r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.down_proj.weight", - r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.down_proj.bias", - r"block.(\d+).mlp.norm": r"layers.\1.post_attention_layernorm.weight", - r"block.(\d+).mlp.gate": r"layers.\1.mlp.router.weight", + r"block.(\d+).attn.sdpa.sinks": r"layers.\1.self_attn.sinks", + r"block.(\d+).attn.norm.scale": r"layers.\1.input_layernorm.weight", + + r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.experts.gate_up_proj", + r"block.(\d+).mlp.mlp1_bias": r"layers.\1.mlp.experts.gate_up_proj_bias", + r"block.(\d+).mlp.mlp2_weight": r"layers.\1.mlp.experts.down_proj", + r"block.(\d+).mlp.mlp2_bias": r"layers.\1.mlp.experts.down_proj_bias", + r"block.(\d+).mlp.norm.scale": r"layers.\1.post_attention_layernorm.weight", + r"block.(\d+).mlp.gate": r"layers.\1.mlp.router", } # fmt: on @@ -101,14 +101,16 @@ def write_model( state_dict = {} for key in all_keys: # Post-process the current_parameter. - new_key = "model." + new_keys.get(key, key) + new_key = new_keys.get(key, key) + if "lm_head" not in new_key: + new_key = "model." + new_key print(f"Processing key: {key} -> {new_key}") - if re.search("qkv_proj.weight", new_key): + if re.search("qkv_proj", new_key): q, k, v = final_[key].chunk(3, dim=-1) - q_key = re.sub(r"(qkv)_proj.weight", "q", new_key) + q_key = re.sub(r"qkv_proj", "q_proj", new_key) state_dict[q_key] = q - k_key = re.sub(r"(qkv)_proj.weight", "k", new_key) - v_key = re.sub(r"(qkv)_proj.weight", "v", new_key) + k_key = re.sub(r"qkv_proj", "k_proj", new_key) + v_key = re.sub(r"qkv_proj", "v_proj", new_key) state_dict[k_key] = k state_dict[v_key] = v else: diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index fa971d4b221e..2218eca41b05 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -22,7 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn - +from ...activations import ACT2FN from ...integrations.flex_attention import flex_attention_forward from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import logging @@ -36,7 +36,6 @@ LlamaRotaryEmbedding, repeat_kv, ) -from ..llama4.modeling_llama4 import Llama4TextExperts from .configuration_openai import OpenaiConfig @@ -47,8 +46,38 @@ class OpenaiRMSNorm(LlamaRMSNorm): pass -class OpenaiExperts(Llama4TextExperts): - pass +class OpenaiExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. + - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + self.down_proj_bias.unsqueeze(1) + next_states = next_states.view(-1, self.hidden_size) + return next_states class OpenaiMLP(nn.Module): @@ -58,7 +87,7 @@ def __init__(self, config): self.hidden_dim = config.hidden_size self.num_local_experts = config.num_local_experts self.experts = OpenaiExperts(config) - self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -142,7 +171,7 @@ def attention_sink(score, b, h, q_idx, kv_idx): class OpenaiAttention(LlamaAttention): def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__(config, layer_idx) - self.sinks = torch.empty(config.num_attention_heads) + self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) class OpenaiDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): From 285d141f5468a16990692074e4bbae52ab6caa6f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:33:14 +0000 Subject: [PATCH 019/342] fix key shapes --- .../models/openai/convert_openai_weights_to_hf.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 673ba966eae1..bb6726ea034a 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -36,7 +36,8 @@ # If a weight needs to be split in two or more keys, use `|` to indicate it. ex: # r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" ORIGINAL_TO_CONVERTED_KEY_MAPPING = { - r"norm.scale": r"norm.weight", + r"norm.weight": r"norm.weight", + r"^norm.scale": r"norm.weight", r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", # special key, wqkv needs to be split afterwards @@ -86,7 +87,7 @@ def write_model( eos_token_id = [128001, 128008, 128009] if instruct else 128001 pad_token_id = 128004 - config = OpenaiConfig.from_pretrained(input_base_path) + config = OpenaiConfig() print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} @@ -106,13 +107,17 @@ def write_model( new_key = "model." + new_key print(f"Processing key: {key} -> {new_key}") if re.search("qkv_proj", new_key): - q, k, v = final_[key].chunk(3, dim=-1) + q_len = config.head_dim * config.num_attention_heads + k_len = config.head_dim * config.num_key_value_heads + q, k, v = final_[key][:q_len, ...], final_[key][q_len:k_len+q_len, ...], final_[key][k_len+q_len:, ...] q_key = re.sub(r"qkv_proj", "q_proj", new_key) - state_dict[q_key] = q k_key = re.sub(r"qkv_proj", "k_proj", new_key) v_key = re.sub(r"qkv_proj", "v_proj", new_key) + state_dict[q_key] = q state_dict[k_key] = k state_dict[v_key] = v + elif re.search("gate_up_proj", new_key) and "bias" not in new_key: + state_dict[new_key] = final_[key].permute(0,2,1) else: state_dict[new_key] = final_[key] From 3f0a54ac3c6e5529de873078cf76d299086524b9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:34:16 +0000 Subject: [PATCH 020/342] convert embed to the correct dtype for now --- src/transformers/models/openai/convert_openai_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index bb6726ea034a..559e4b2c2e70 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -119,7 +119,7 @@ def write_model( elif re.search("gate_up_proj", new_key) and "bias" not in new_key: state_dict[new_key] = final_[key].permute(0,2,1) else: - state_dict[new_key] = final_[key] + state_dict[new_key] = final_[key].to(torch.bfloat16) # TODO slow, let's rmeove del final_ gc.collect() From 62af880220e73a7fbdabd7e1960130589ea1d72a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 28 May 2025 19:35:01 +0000 Subject: [PATCH 021/342] saved the checkpiont --- src/transformers/models/openai/convert_openai_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index 559e4b2c2e70..c6819007e043 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -37,7 +37,7 @@ # r"layers.(\d+).attention.wqkv.weight": r"layers.\1.self_attn.q|k|v|_proj.weight" ORIGINAL_TO_CONVERTED_KEY_MAPPING = { r"norm.weight": r"norm.weight", - r"^norm.scale": r"norm.weight", + r"\nnorm.scale": r"\nnorm.weight", r"unembedding.weight": r"lm_head.weight", r"embedding": r"embed_tokens", # special key, wqkv needs to be split afterwards From adcadd9fe45d0e8ebe186754aff862a0f871db92 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 04:56:15 +0000 Subject: [PATCH 022/342] fix model type and wholes --- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 7 +- .../models/auto/tokenization_auto.py | 2 +- .../models/openai/configuration_openai.py | 2 +- .../openai/convert_openai_weights_to_hf.py | 78 ++++++++++++++++--- .../models/openai/modeling_openai.py | 25 +++--- 6 files changed, 89 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aadf21a90a83..43d8cb18c745 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -238,7 +238,7 @@ ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), - ("openai", "OpenaiConfig"), + ("openai-moe", "OpenaiConfig"), ("openai-gpt", "OpenAIGPTConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), @@ -617,7 +617,7 @@ ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), - ("openai", "openai"), + ("openai-moe", "openai"), ("openai-gpt", "OpenAI GPT"), ("opt", "OPT"), ("owlv2", "OWLv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3881dcffa88b..5b4c0062be08 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -232,7 +232,7 @@ ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), - ("openai", "OpenaiModel"), + ("openai-moe", "OpenaiModel"), ("openai-gpt", "OpenAIGPTModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), @@ -607,7 +607,7 @@ ("olmo2", "Olmo2ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), - ("openai", "OpenaiForCausalLM"), + ("openai-moe", "OpenaiForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), @@ -1120,7 +1120,6 @@ ("nezha", "NezhaForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("open-llama", "OpenLlamaForSequenceClassification"), - ("openai", "OpenaiForSequenceClassification"), ("openai-gpt", "OpenAIGPTForSequenceClassification"), ("opt", "OPTForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), @@ -1211,7 +1210,6 @@ ("nemotron", "NemotronForQuestionAnswering"), ("nezha", "NezhaForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), - ("openai", "OpenaiForQuestionAnswering"), ("opt", "OPTForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), ("qwen2", "Qwen2ForQuestionAnswering"), @@ -1317,7 +1315,6 @@ ("nemotron", "NemotronForTokenClassification"), ("nezha", "NezhaForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"), - ("openai", "OpenaiForTokenClassification"), ("persimmon", "PersimmonForTokenClassification"), ("phi", "PhiForTokenClassification"), ("phi3", "Phi3ForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c5fe3358df4d..d1569c77b63c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -411,7 +411,7 @@ ), ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ( - "openai", + "openai-moe", ( None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None, diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index d3637633ffbc..147047e60a27 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -24,7 +24,7 @@ class OpenaiConfig(PretrainedConfig): - model_type = "openai" + model_type = "openai-moe" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `OpenaiModel` base_model_tp_plan = { diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai/convert_openai_weights_to_hf.py index c6819007e043..9fcec5f6c085 100644 --- a/src/transformers/models/openai/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai/convert_openai_weights_to_hf.py @@ -155,7 +155,58 @@ def write_model( generation_config.save_pretrained(model_path) +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + +import tiktoken class OpenaiConverter(TikTokenConverter): + def extract_vocab_merges_from_model(self, tiktoken_url: str): + tokenizer = tiktoken.get_encoding(tiktoken_url) + + bpe_ranks = tokenizer._mergeable_ranks + byte_encoder = bytes_to_unicode() + + def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + merges = [] + vocab = {} + for token, rank in bpe_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + local = [] + for index in range(1, len(token)): + piece_l, piece_r = token[:index], token[index:] + if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks: + local.append((piece_l, piece_r, rank)) + local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False) + merges.extend(local) + merges = sorted(merges, key=lambda val: val[2], reverse=False) + merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges] + return vocab, merges + def __init__( self, vocab_file, @@ -166,8 +217,16 @@ def __init__( **kwargs, ): super().__init__(vocab_file, pattern=pattern) + # TODO 1st donwload the vocabfile!!! - self.additional_special_tokens = special_tokens + tokenizer = tiktoken.get_encoding(vocab_file) + special_tokens = tokenizer._special_tokens.keys() + + self.additional_special_tokens = {'<|endoftext|>': 199999, '<|endofprompt|>': 200018} + for k in range(200000, 200018): + self.additional_special_tokens[f"<|reserved_{k}|>"] = k + sorted_list = sorted(self.additional_special_tokens.items(), key=lambda x: x[1]) + self.additional_special_tokens = [k[0] for k in sorted_list] tokenizer = self.converted() if chat_template is not None: kwargs["chat_template"] = chat_template @@ -180,7 +239,6 @@ def __init__( def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): - model_max_length = CONTEXT_LENGTH pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605 special_tokens = [] @@ -213,7 +271,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): vocab_file=tokenizer_path, pattern=pattern, special_tokens=special_tokens, - model_max_length=model_max_length, + model_max_length=None, chat_template=chat_template if instruct else None, bos_token="<|begin_of_text|>", eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", @@ -257,15 +315,15 @@ def main(): help="Whether the model is an instruct model", ) args = parser.parse_args() - write_model( - model_path=args.output_dir, - input_base_path=args.input_dir, - safe_serialization=args.safe_serialization, - instruct=args.instruct, - ) + # write_model( + # model_path=args.output_dir, + # input_base_path=args.input_dir, + # safe_serialization=args.safe_serialization, + # instruct=args.instruct, + # ) write_tokenizer( - tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"), + tokenizer_path="o200k_base", save_dir=args.output_dir, instruct=args.instruct, ) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 2218eca41b05..5ad61855e362 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -46,6 +46,12 @@ class OpenaiRMSNorm(LlamaRMSNorm): pass +def swiglu(x, alpha: float = 1.702): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + 1) + class OpenaiExperts(nn.Module): def __init__(self, config): super().__init__() @@ -73,10 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: torch.Tensor """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias.unsqueeze(1) - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + self.down_proj_bias.unsqueeze(1) - next_states = next_states.view(-1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias + swiglu = swiglu(gate_up, alpha=1.702) + next_states = torch.bmm(swiglu, self.down_proj) + self.down_proj_bias return next_states @@ -90,19 +95,19 @@ def __init__(self, config): self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) def forward(self, hidden_states): + # we don't slice weight as its not compile compatible hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1, sorted=True) router_scores = ( torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) + routed_in = hidden_states.repeat(self.num_local_experts, 1) - routed_in = routed_in * router_scores.reshape(-1, 1) routed_out = self.experts(routed_in) - out = self.shared_expert(hidden_states) - out.add_(routed_out.reshape(self.num_local_experts, -1, self.hidden_dim).sum(dim=0)) - return out, router_scores + routed_out = routed_out.view(-1, self.hidden_size) * router_scores.reshape(self.num_local_experts, -1, 1) + hidden_states = routed_out.sum(dim=0) + return hidden_states, router_scores class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): From a68c00be96a6cf7b64b50eec0ef65ad4f1395952 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:01:30 +0000 Subject: [PATCH 023/342] oups --- ...penai_original_tf_checkpoint_to_pytorch.py | 74 -- .../models/openai/modeling_tf_openai.py | 937 ------------------ .../models/openai/tokenization_openai.py | 396 -------- .../models/openai/tokenization_openai_fast.py | 66 -- .../models/{openai => openai_moe}/__init__.py | 0 .../configuration_openai.py | 0 .../convert_openai_weights_to_hf.py | 0 .../{openai => openai_moe}/modeling_openai.py | 0 8 files changed, 1473 deletions(-) delete mode 100755 src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py delete mode 100644 src/transformers/models/openai/modeling_tf_openai.py delete mode 100644 src/transformers/models/openai/tokenization_openai.py delete mode 100644 src/transformers/models/openai/tokenization_openai_fast.py rename src/transformers/models/{openai => openai_moe}/__init__.py (100%) rename src/transformers/models/{openai => openai_moe}/configuration_openai.py (100%) rename src/transformers/models/{openai => openai_moe}/convert_openai_weights_to_hf.py (100%) rename src/transformers/models/{openai => openai_moe}/modeling_openai.py (100%) diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py deleted file mode 100755 index 3d5218c20426..000000000000 --- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +++ /dev/null @@ -1,74 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert OpenAI GPT checkpoint.""" - -import argparse - -import torch - -from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging - - -logging.set_verbosity_info() - - -def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): - # Construct model - if openai_config_file == "": - config = OpenAIGPTConfig() - else: - config = OpenAIGPTConfig.from_json_file(openai_config_file) - model = OpenAIGPTModel(config) - - # Load weights from numpy - load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME - print(f"Save PyTorch model to {pytorch_weights_dump_path}") - torch.save(model.state_dict(), pytorch_weights_dump_path) - print(f"Save configuration file to {pytorch_config_dump_path}") - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--openai_checkpoint_folder_path", - default=None, - type=str, - required=True, - help="Path to the TensorFlow checkpoint path.", - ) - parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." - ) - parser.add_argument( - "--openai_config_file", - default="", - type=str, - help=( - "An optional config json file corresponding to the pre-trained OpenAI model. \n" - "This specifies the model architecture." - ), - ) - args = parser.parse_args() - convert_openai_checkpoint_to_pytorch( - args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path - ) diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py deleted file mode 100644 index 3856711d1062..000000000000 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ /dev/null @@ -1,937 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TF 2.0 OpenAI GPT model.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import tensorflow as tf - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFConv1D, - TFModelInputType, - TFPreTrainedModel, - TFSequenceClassificationLoss, - TFSequenceSummary, - TFSharedEmbeddings, - get_initializer, - keras, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_openai import OpenAIGPTConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt" -_CONFIG_FOR_DOC = "OpenAIGPTConfig" - - -class TFAttention(keras.layers.Layer): - def __init__(self, nx, config, scale=False, **kwargs): - super().__init__(**kwargs) - - n_state = nx # in Attention: n_state=768 (nx=n_embd) - # [switch nx => n_state from Block to Attention to keep identical to TF implementation] - assert n_state % config.n_head == 0, ( - f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}" - ) - self.n_head = config.n_head - self.split_size = n_state - self.scale = scale - self.output_attentions = config.output_attentions - - self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") - self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") - self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) - self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) - self.n_state = n_state - self.pruned_heads = set() - - def prune_heads(self, heads): - pass - - @staticmethod - def causal_attention_mask(nd, ns): - """ - 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), - -1, ns-nd), but doesn't produce garbage on TPUs. - """ - i = tf.range(nd)[:, None] - j = tf.range(ns) - m = i >= j - ns + nd - return m - - def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False): - # q, k, v have shape [batch, heads, sequence, features] - w = tf.matmul(q, k, transpose_b=True) - if self.scale: - dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores - w = w / tf.math.sqrt(dk) - - # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. - _, _, nd, ns = shape_list(w) - b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype) - b = tf.reshape(b, [1, 1, nd, ns]) - w = w * b - 1e4 * (1 - b) - - if attention_mask is not None: - # Apply the attention mask - attention_mask = tf.cast(attention_mask, dtype=w.dtype) - w = w + attention_mask - - w = stable_softmax(w, axis=-1) - w = self.attn_dropout(w, training=training) - - # Mask heads if we want to - if head_mask is not None: - w = w * head_mask - - outputs = [tf.matmul(w, v)] - if output_attentions: - outputs.append(w) - return outputs - - def merge_heads(self, x): - x = tf.transpose(x, [0, 2, 1, 3]) - x_shape = shape_list(x) - new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] - return tf.reshape(x, new_x_shape) - - def split_heads(self, x): - x_shape = shape_list(x) - new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] - x = tf.reshape(x, new_x_shape) - return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - - def call(self, x, attention_mask, head_mask, output_attentions, training=False): - x = self.c_attn(x) - query, key, value = tf.split(x, 3, axis=2) - query = self.split_heads(query) - key = self.split_heads(key) - value = self.split_heads(value) - - attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training) - a = attn_outputs[0] - - a = self.merge_heads(a) - a = self.c_proj(a) - a = self.resid_dropout(a, training=training) - - outputs = [a] + attn_outputs[1:] - return outputs # a, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "c_attn", None) is not None: - with tf.name_scope(self.c_attn.name): - self.c_attn.build([None, None, self.n_state * 3]) - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.n_state]) - - -class TFMLP(keras.layers.Layer): - def __init__(self, n_state, config, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") - self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") - self.act = get_tf_activation("gelu") - self.dropout = keras.layers.Dropout(config.resid_pdrop) - self.nx = nx - self.n_state = n_state - - def call(self, x, training=False): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - h2 = self.dropout(h2, training=training) - return h2 - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "c_fc", None) is not None: - with tf.name_scope(self.c_fc.name): - self.c_fc.build([None, None, self.n_state]) - if getattr(self, "c_proj", None) is not None: - with tf.name_scope(self.c_proj.name): - self.c_proj.build([None, None, self.nx]) - - -class TFBlock(keras.layers.Layer): - def __init__(self, config, scale=False, **kwargs): - super().__init__(**kwargs) - nx = config.n_embd - self.attn = TFAttention(nx, config, scale, name="attn") - self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") - self.mlp = TFMLP(4 * nx, config, name="mlp") - self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") - self.nx = nx - - def call(self, x, attention_mask, head_mask, output_attentions, training=False): - output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training) - a = output_attn[0] # output_attn: a, (attentions) - - n = self.ln_1(x + a) - m = self.mlp(n, training=training) - h = self.ln_2(n + m) - - outputs = [h] + output_attn[1:] - return outputs # x, (attentions) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "attn", None) is not None: - with tf.name_scope(self.attn.name): - self.attn.build(None) - if getattr(self, "ln_1", None) is not None: - with tf.name_scope(self.ln_1.name): - self.ln_1.build([None, None, self.nx]) - if getattr(self, "mlp", None) is not None: - with tf.name_scope(self.mlp.name): - self.mlp.build(None) - if getattr(self, "ln_2", None) is not None: - with tf.name_scope(self.ln_2.name): - self.ln_2.build([None, None, self.nx]) - - -@keras_serializable -class TFOpenAIGPTMainLayer(keras.layers.Layer): - config_class = OpenAIGPTConfig - - def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.return_dict = config.use_return_dict - self.num_hidden_layers = config.n_layer - self.n_embd = config.n_embd - self.n_positions = config.n_positions - self.initializer_range = config.initializer_range - - self.tokens_embed = TFSharedEmbeddings( - config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed" - ) - self.drop = keras.layers.Dropout(config.embd_pdrop) - self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] - - def build(self, input_shape=None): - with tf.name_scope("positions_embed"): - self.positions_embed = self.add_weight( - name="embeddings", - shape=[self.n_positions, self.n_embd], - initializer=get_initializer(self.initializer_range), - ) - - if self.built: - return - self.built = True - if getattr(self, "tokens_embed", None) is not None: - with tf.name_scope(self.tokens_embed.name): - self.tokens_embed.build(None) - if getattr(self, "h", None) is not None: - for layer in self.h: - with tf.name_scope(layer.name): - layer.build(None) - - def get_input_embeddings(self): - return self.tokens_embed - - def set_input_embeddings(self, value): - self.tokens_embed.weight = value - self.tokens_embed.vocab_size = shape_list(value)[0] - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - raise NotImplementedError - - @unpack_inputs - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFBaseModelOutput]: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if position_ids is None: - position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0) - - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - one_cst = tf.constant(1.0) - attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) - attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) - else: - attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if head_mask is not None: - raise NotImplementedError - else: - head_mask = [None] * self.num_hidden_layers - # head_mask = tf.constant([0] * self.num_hidden_layers) - - position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - - if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.tokens_embed(input_ids, mode="embedding") - position_embeds = tf.gather(self.positions_embed, position_ids) - if token_type_ids is not None: - token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids") - token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding") - else: - token_type_embeds = 0 - hidden_states = inputs_embeds + position_embeds + token_type_embeds - hidden_states = self.drop(hidden_states, training=training) - - output_shape = input_shape + [shape_list(hidden_states)[-1]] - - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): - if output_hidden_states: - all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) - - outputs = block( - hidden_states, - attention_mask, - head_mask[i], - output_attentions, - training=training, - ) - hidden_states = outputs[0] - if output_attentions: - all_attentions = all_attentions + (outputs[1],) - - hidden_states = tf.reshape(hidden_states, output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if output_attentions: - # let the number of heads free (-1) so we can extract attention even after head pruning - attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] - all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = OpenAIGPTConfig - base_model_prefix = "transformer" - - -@dataclass -class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`): - Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape - `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - logits: Optional[tf.Tensor] = None - mc_logits: Optional[tf.Tensor] = None - hidden_states: Tuple[tf.Tensor] | None = None - attentions: Tuple[tf.Tensor] | None = None - - -OPENAI_GPT_START_DOCSTRING = r""" - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -OPENAI_GPT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - - -@add_start_docstrings( - "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFBaseModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFBaseModelOutput]: - outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - return outputs - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - # OpenAIGPT does not have past caching features - self.supports_xla_generation = False - - def get_output_embeddings(self): - return self.get_input_embeddings() - - def set_output_embeddings(self, value): - self.set_input_embeddings(value) - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFCausalLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - labels: np.ndarray | tf.Tensor | None = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFCausalLMOutput]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - - logits = self.transformer.tokens_embed(hidden_states, mode="linear") - - loss = None - if labels is not None: - # shift labels to the left and cut last logit token - shifted_logits = logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFCausalLMOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def prepare_inputs_for_generation(self, inputs, **kwargs): - return {"input_ids": inputs} - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -@add_start_docstrings( - """ - OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for - RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the - input embeddings, the classification head takes as input the input of a specified classification token index in the - input sequence). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - config.num_labels = 1 - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - self.multiple_choice_head = TFSequenceSummary( - config, initializer_range=config.initializer_range, name="multiple_choice_head" - ) - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - mc_token_ids: np.ndarray | tf.Tensor | None = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFOpenAIGPTDoubleHeadsModelOutput]: - r""" - mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - - Return: - - Examples: - - ```python - >>> import tensorflow as tf - >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel - - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") - >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") - - >>> # Add a [CLS] to the vocabulary (we should train it also!) - >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"}) - >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size - >>> print(tokenizer.cls_token_id, len(tokenizer)) # The newly token the last token of the vocabulary - - >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] - >>> encoding = tokenizer(choices, return_tensors="tf") - >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()} - >>> inputs["mc_token_ids"] = tf.constant( - ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1] - ... )[ - ... None, : - ... ] # Batch size 1 - >>> outputs = model(inputs) - >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] - ```""" - - if input_ids is not None: - input_shapes = shape_list(input_ids) - else: - input_shapes = shape_list(inputs_embeds)[:-1] - - seq_length = input_shapes[-1] - flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None - flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None - flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None - flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None - transformer_outputs = self.transformer( - flat_input_ids, - flat_attention_mask, - flat_token_type_ids, - flat_position_ids, - head_mask, - inputs_embeds, - output_attentions, - output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) - if return_dict and output_hidden_states: - # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the - # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) - all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) - else: - all_hidden_states = None - lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) - mc_logits = tf.squeeze(mc_logits, axis=-1) - - if not return_dict: - return (lm_logits, mc_logits) + transformer_outputs[1:] - - return TFOpenAIGPTDoubleHeadsModelOutput( - logits=lm_logits, - mc_logits=mc_logits, - hidden_states=all_hidden_states, - attentions=transformer_outputs.attentions, - ) - - @property - def input_signature(self): - return { - "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), - "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), - } - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - if getattr(self, "multiple_choice_head", None) is not None: - with tf.name_scope(self.multiple_choice_head.name): - self.multiple_choice_head.build(None) - - -@add_start_docstrings( - """ - The OpenAI GPT Model transformer with a sequence classification head on top (linear layer). - - [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal - models (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - OPENAI_GPT_START_DOCSTRING, -) -class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.num_labels = config.num_labels - self.score = keras.layers.Dense( - config.num_labels, - kernel_initializer=get_initializer(config.initializer_range), - name="score", - use_bias=False, - ) - self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - self.config = config - - @unpack_inputs - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def call( - self, - input_ids: TFModelInputType | None = None, - attention_mask: np.ndarray | tf.Tensor | None = None, - token_type_ids: np.ndarray | tf.Tensor | None = None, - position_ids: np.ndarray | tf.Tensor | None = None, - head_mask: np.ndarray | tf.Tensor | None = None, - inputs_embeds: np.ndarray | tf.Tensor | None = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - labels: np.ndarray | tf.Tensor | None = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFSequenceClassifierOutput]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - """ - transformer_outputs = self.transformer( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - logits_shape = shape_list(logits) - batch_size = logits_shape[0] - - if self.config.pad_token_id is None: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - else: - if input_ids is not None: - token_indices = tf.range(shape_list(input_ids)[-1]) - non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) - last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) - else: - last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - loss = None - - pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) - - if labels is not None: - if self.config.pad_token_id is None and logits_shape[0] != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - - loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return TFSequenceClassifierOutput( - loss=loss, - logits=pooled_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - def build(self, input_shape=None): - if self.built: - return - self.built = True - if getattr(self, "score", None) is not None: - with tf.name_scope(self.score.name): - self.score.build([None, None, self.config.n_embd]) - if getattr(self, "transformer", None) is not None: - with tf.name_scope(self.transformer.name): - self.transformer.build(None) - - -__all__ = [ - "TFOpenAIGPTDoubleHeadsModel", - "TFOpenAIGPTForSequenceClassification", - "TFOpenAIGPTLMHeadModel", - "TFOpenAIGPTMainLayer", - "TFOpenAIGPTModel", - "TFOpenAIGPTPreTrainedModel", -] diff --git a/src/transformers/models/openai/tokenization_openai.py b/src/transformers/models/openai/tokenization_openai.py deleted file mode 100644 index cbfb41fc888f..000000000000 --- a/src/transformers/models/openai/tokenization_openai.py +++ /dev/null @@ -1,396 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tokenization classes for OpenAI GPT.""" - -import json -import os -import re -import unicodedata -from typing import Optional, Tuple - -from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace -from ...utils import logging - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = { - "vocab_file": "vocab.json", - "merges_file": "merges.txt", -} - - -# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens - - -# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer -class BasicTokenizer: - """ - Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). - - Args: - do_lower_case (`bool`, *optional*, defaults to `True`): - Whether or not to lowercase the input when tokenizing. - never_split (`Iterable`, *optional*): - Collection of tokens which will never be split during tokenization. Only has an effect when - `do_basic_tokenize=True` - tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): - Whether or not to tokenize Chinese characters. - - This should likely be deactivated for Japanese (see this - [issue](https://github.com/huggingface/transformers/issues/328)). - strip_accents (`bool`, *optional*): - Whether or not to strip all accents. If this option is not specified, then it will be determined by the - value for `lowercase` (as in the original BERT). - do_split_on_punc (`bool`, *optional*, defaults to `True`): - In some instances we want to skip the basic punctuation splitting so that later tokenization can capture - the full context of the words, such as contractions. - """ - - def __init__( - self, - do_lower_case=True, - never_split=None, - tokenize_chinese_chars=True, - strip_accents=None, - do_split_on_punc=True, - ): - if never_split is None: - never_split = [] - self.do_lower_case = do_lower_case - self.never_split = set(never_split) - self.tokenize_chinese_chars = tokenize_chinese_chars - self.strip_accents = strip_accents - self.do_split_on_punc = do_split_on_punc - - def tokenize(self, text, never_split=None): - """ - Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. - - Args: - never_split (`List[str]`, *optional*) - Kept for backward compatibility purposes. Now implemented directly at the base class level (see - [`PreTrainedTokenizer.tokenize`]) List of token not to split. - """ - # union() returns a new set by concatenating the two sets. - never_split = self.never_split.union(set(never_split)) if never_split else self.never_split - text = self._clean_text(text) - - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). - if self.tokenize_chinese_chars: - text = self._tokenize_chinese_chars(text) - # prevents treating the same character with different unicode codepoints as different characters - unicode_normalized_text = unicodedata.normalize("NFC", text) - orig_tokens = whitespace_tokenize(unicode_normalized_text) - split_tokens = [] - for token in orig_tokens: - if token not in never_split: - if self.do_lower_case: - token = token.lower() - if self.strip_accents is not False: - token = self._run_strip_accents(token) - elif self.strip_accents: - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token, never_split)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text, never_split=None): - """Splits punctuation on a piece of text.""" - if not self.do_split_on_punc or (never_split is not None and text in never_split): - return [text] - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. - if ( - (cp >= 0x4E00 and cp <= 0x9FFF) - or (cp >= 0x3400 and cp <= 0x4DBF) # - or (cp >= 0x20000 and cp <= 0x2A6DF) # - or (cp >= 0x2A700 and cp <= 0x2B73F) # - or (cp >= 0x2B740 and cp <= 0x2B81F) # - or (cp >= 0x2B820 and cp <= 0x2CEAF) # - or (cp >= 0xF900 and cp <= 0xFAFF) - or (cp >= 0x2F800 and cp <= 0x2FA1F) # - ): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xFFFD or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - -def get_pairs(word): - """ - Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length - strings) - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def text_standardize(text): - """ - fixes some issues the spacy tokenizer had on books corpus also does some whitespace standardization - """ - text = text.replace("—", "-") - text = text.replace("–", "-") - text = text.replace("―", "-") - text = text.replace("…", "...") - text = text.replace("´", "'") - text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text) - text = re.sub(r"\s*\n\s*", " \n ", text) - text = re.sub(r"[^\S\n]+", " ", text) - return text.strip() - - -class OpenAIGPTTokenizer(PreTrainedTokenizer): - """ - Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities: - - - lowercases all inputs, - - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's - `BasicTokenizer` if not. - - This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to - this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - merges_file (`str`): - Path to the merges file. - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - """ - - vocab_files_names = VOCAB_FILES_NAMES - model_input_names = ["input_ids", "attention_mask"] - - def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): - try: - import ftfy - from spacy.lang.en import English - - _nlp = English() - self.nlp = _nlp.tokenizer - self.fix_text = ftfy.fix_text - except ImportError: - logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") - self.nlp = BasicTokenizer(do_lower_case=True) - self.fix_text = None - - with open(vocab_file, encoding="utf-8") as vocab_handle: - self.encoder = json.load(vocab_handle) - self.decoder = {v: k for k, v in self.encoder.items()} - with open(merges_file, encoding="utf-8") as merges_handle: - merges = merges_handle.read().split("\n")[1:-1] - merges = [tuple(merge.split()) for merge in merges] - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {} - - super().__init__(unk_token=unk_token, **kwargs) - - @property - def do_lower_case(self): - return True - - @property - def vocab_size(self): - return len(self.encoder) - - def get_vocab(self): - return dict(self.encoder, **self.added_tokens_encoder) - - def bpe(self, token): - word = tuple(token[:-1]) + (token[-1] + "",) - if token in self.cache: - return self.cache[token] - pairs = get_pairs(word) - - if not pairs: - return token + "" - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - except ValueError: - new_word.extend(word[i:]) - break - else: - new_word.extend(word[i:j]) - i = j - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - if word == "\n ": - word = "\n" - self.cache[token] = word - return word - - def _tokenize(self, text): - """Tokenize a string.""" - split_tokens = [] - if self.fix_text is None: - # Using BERT's BasicTokenizer - text = self.nlp.tokenize(text) - for token in text: - split_tokens.extend(list(self.bpe(token).split(" "))) - else: - # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) - text = self.nlp(text_standardize(self.fix_text(text))) - for token in text: - split_tokens.extend(list(self.bpe(token.text.lower()).split(" "))) - return split_tokens - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.encoder.get(token, self.encoder.get(self.unk_token)) - - def _convert_id_to_token(self, index): - """Converts an id in a token (BPE) using the vocab.""" - return self.decoder.get(index, self.unk_token) - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - out_string = "".join(tokens).replace("", " ").strip() - return out_string - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - merge_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] - ) - - with open(vocab_file, "w", encoding="utf-8") as f: - f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") - - index = 0 - with open(merge_file, "w", encoding="utf-8") as writer: - writer.write("#version: 0.2\n") - for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning( - f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." - " Please check that the tokenizer is not corrupted!" - ) - index = token_index - writer.write(" ".join(bpe_tokens) + "\n") - index += 1 - - return vocab_file, merge_file - - -__all__ = ["OpenAIGPTTokenizer"] diff --git a/src/transformers/models/openai/tokenization_openai_fast.py b/src/transformers/models/openai/tokenization_openai_fast.py deleted file mode 100644 index c17d7d29b7dd..000000000000 --- a/src/transformers/models/openai/tokenization_openai_fast.py +++ /dev/null @@ -1,66 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Fast Tokenization classes for OpenAI GPT.""" - -from typing import Optional, Tuple - -from ...tokenization_utils_fast import PreTrainedTokenizerFast -from ...utils import logging -from .tokenization_openai import OpenAIGPTTokenizer - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} - - -class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with - the following peculiarities: - - - lower case all inputs - - uses BERT's BasicTokenizer for pre-BPE tokenization - - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - merges_file (`str`): - Path to the merges file. - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - """ - - vocab_files_names = VOCAB_FILES_NAMES - model_input_names = ["input_ids", "attention_mask"] - slow_tokenizer_class = OpenAIGPTTokenizer - - def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="", **kwargs): - super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs) - - @property - def do_lower_case(self): - return True - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - files = self._tokenizer.model.save(save_directory, name=filename_prefix) - return tuple(files) - - -__all__ = ["OpenAIGPTTokenizerFast"] diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai_moe/__init__.py similarity index 100% rename from src/transformers/models/openai/__init__.py rename to src/transformers/models/openai_moe/__init__.py diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai_moe/configuration_openai.py similarity index 100% rename from src/transformers/models/openai/configuration_openai.py rename to src/transformers/models/openai_moe/configuration_openai.py diff --git a/src/transformers/models/openai/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py similarity index 100% rename from src/transformers/models/openai/convert_openai_weights_to_hf.py rename to src/transformers/models/openai_moe/convert_openai_weights_to_hf.py diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai_moe/modeling_openai.py similarity index 100% rename from src/transformers/models/openai/modeling_openai.py rename to src/transformers/models/openai_moe/modeling_openai.py From 6dfb395cc2ab2cba2cbe996afa4b0f97a12bea42 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:03:59 +0000 Subject: [PATCH 024/342] rename --- .../{configuration_openai.py => configuration_openai_moe.py} | 0 .../openai_moe/{modeling_openai.py => modeling_openai_moe.py} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/transformers/models/openai_moe/{configuration_openai.py => configuration_openai_moe.py} (100%) rename src/transformers/models/openai_moe/{modeling_openai.py => modeling_openai_moe.py} (99%) diff --git a/src/transformers/models/openai_moe/configuration_openai.py b/src/transformers/models/openai_moe/configuration_openai_moe.py similarity index 100% rename from src/transformers/models/openai_moe/configuration_openai.py rename to src/transformers/models/openai_moe/configuration_openai_moe.py diff --git a/src/transformers/models/openai_moe/modeling_openai.py b/src/transformers/models/openai_moe/modeling_openai_moe.py similarity index 99% rename from src/transformers/models/openai_moe/modeling_openai.py rename to src/transformers/models/openai_moe/modeling_openai_moe.py index 5ad61855e362..3989d693c6f8 100644 --- a/src/transformers/models/openai_moe/modeling_openai.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -36,7 +36,7 @@ LlamaRotaryEmbedding, repeat_kv, ) -from .configuration_openai import OpenaiConfig +from .configuration_openai_moe import OpenaiConfig logger = logging.get_logger(__name__) From ef4da0a6a05878ab5aee08954695cdb5c24f8bae Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:06:42 +0000 Subject: [PATCH 025/342] more naming issues haha --- src/transformers/models/openai_moe/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/__init__.py b/src/transformers/models/openai_moe/__init__.py index f777d1981929..b35ca7cd1dc6 100644 --- a/src/transformers/models/openai_moe/__init__.py +++ b/src/transformers/models/openai_moe/__init__.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: - from .configuration_openai import * - from .modeling_openai import * + from .configuration_openai_moe import * + from .modeling_openai_moe import * else: import sys From 99b4403b6c065f231f64b28f166d8fb359e80a90 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:08:18 +0000 Subject: [PATCH 026/342] revert removal of the foldree --- .../models/openai/configuration_openai.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 src/transformers/models/openai/configuration_openai.py diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py new file mode 100644 index 000000000000..7dc1525fe0ef --- /dev/null +++ b/src/transformers/models/openai/configuration_openai.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OpenAIGPTConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is + used to instantiate a GPT model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT + [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 40478): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`]. + n_positions (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + afn (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`int`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`str`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + + + Examples: + + ```python + >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel + + >>> # Initializing a GPT configuration + >>> configuration = OpenAIGPTConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = OpenAIGPTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "openai-gpt" + attribute_map = { + "max_position_embeddings": "n_positions", + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=40478, + n_positions=512, + n_embd=768, + n_layer=12, + n_head=12, + afn="gelu", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.afn = afn + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + super().__init__(**kwargs) + + +__all__ = ["OpenAIGPTConfig"] \ No newline at end of file From 5e3ee5c54368e6c73c3b5ac31d1cffe1b6b7b7bd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:19:19 +0000 Subject: [PATCH 027/342] super small nit bias needs expansion as we repeated the inputs --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 3989d693c6f8..cdf694a2c4cc 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: torch.Tensor """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] swiglu = swiglu(gate_up, alpha=1.702) - next_states = torch.bmm(swiglu, self.down_proj) + self.down_proj_bias + next_states = torch.bmm(swiglu, self.down_proj) + self.down_proj_bias[:,None,:] return next_states From 26e7aa8612b507b0ed10efddc351e5b38349e047 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:21:53 +0000 Subject: [PATCH 028/342] nit, swiglu fn was getting erase --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index cdf694a2c4cc..04e3ec99dd27 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -80,8 +80,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] - swiglu = swiglu(gate_up, alpha=1.702) - next_states = torch.bmm(swiglu, self.down_proj) + self.down_proj_bias[:,None,:] + swiglu_ = swiglu(gate_up) + next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:,None,:] return next_states From b9cf11ea0796cf060528cbaf75cf21c5bc093713 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:40:59 +0000 Subject: [PATCH 029/342] nits --- .../openai_moe/configuration_openai_moe.py | 5 +- .../models/openai_moe/modeling_openai_moe.py | 60 +++- .../models/openai_moe/modular_openai_moe.py | 257 ++++++++++++++++++ 3 files changed, 313 insertions(+), 9 deletions(-) create mode 100644 src/transformers/models/openai_moe/modular_openai_moe.py diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 147047e60a27..deda53a27111 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -64,6 +64,8 @@ def __init__( rope_scaling=None, attention_dropout: float = 0.0, num_experts_per_tok=4, + router_aux_loss_coef: float = 0.9, + output_router_logits=False, **kwargs, ): self.vocab_size = vocab_size @@ -95,7 +97,8 @@ def __init__( self.attention_bias = True self.mlp_bias = False self.max_position_embeddings = 8192 - + self.router_aux_loss_coef = router_aux_loss_coef + self.output_router_logits = output_router_logits super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 04e3ec99dd27..436b62d5d3c4 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -17,8 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - +from typing import Optional, Tuple +from ...cache_utils import Cache import torch import torch.utils.checkpoint from torch import nn @@ -36,6 +36,7 @@ LlamaRotaryEmbedding, repeat_kv, ) +from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from .configuration_openai_moe import OpenaiConfig @@ -100,13 +101,13 @@ def forward(self, hidden_states): router_logits = self.router(hidden_states) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1, sorted=True) router_scores = ( - torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) + torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) routed_in = hidden_states.repeat(self.num_local_experts, 1) routed_out = self.experts(routed_in) - routed_out = routed_out.view(-1, self.hidden_size) * router_scores.reshape(self.num_local_experts, -1, 1) - hidden_states = routed_out.sum(dim=0) + routed_out = routed_out * router_scores.reshape(self.num_local_experts, -1, 1) + hidden_states = routed_out.sum(dim=0)[None, ...] return hidden_states, router_scores @@ -187,7 +188,48 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if kwargs.get("output_router_logits", False): + outputs += (router_logits,) + return outputs + class OpenaiPreTrainedModel(LlamaPreTrainedModel): config_class = OpenaiConfig base_model_prefix = "model" @@ -195,7 +237,9 @@ class OpenaiPreTrainedModel(LlamaPreTrainedModel): _no_split_modules = ["OpenaiDecoderLayer"] -class OpenaiModel(LlamaModel): +class OpenaiModel(MixtralModel, OpenaiPreTrainedModel): + _no_split_modules = ["OpenaiDecoderLayer"] + def __init__(self, config: OpenaiConfig): super().__init__(config) self.rope = OpenaiRotaryEmbedding(config) @@ -203,7 +247,7 @@ def __init__(self, config: OpenaiConfig): self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) -class OpenaiForCausalLM(LlamaForCausalLM): +class OpenaiForCausalLM(MixtralForCausalLM, OpenaiPreTrainedModel): def __init__(self, config: OpenaiConfig): super().__init__(config) self.model = OpenaiModel(config) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py new file mode 100644 index 000000000000..77f40546be27 --- /dev/null +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple +from ...cache_utils import Cache +import torch +from torch import nn +from ...activations import ACT2FN +from ...integrations.flex_attention import flex_attention_forward +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + repeat_kv, +) +from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel +from .configuration_openai_moe import OpenaiConfig + + +logger = logging.get_logger(__name__) + + +class OpenaiRMSNorm(LlamaRMSNorm): + pass + + +def swiglu(x, alpha: float = 1.702): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + 1) + +class OpenaiExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. + - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] + swiglu_ = swiglu(gate_up) + next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:,None,:] + return next_states + + +class OpenaiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_local_experts = config.num_local_experts + self.experts = OpenaiExperts(config) + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) + + def forward(self, hidden_states): + # we don't slice weight as its not compile compatible + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.router(hidden_states) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1, sorted=True) + router_scores = ( + torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) + ) + + routed_in = hidden_states.repeat(self.num_local_experts, 1) + routed_out = self.experts(routed_in) + routed_out = routed_out * router_scores.reshape(self.num_local_experts, -1, 1) + hidden_states = routed_out.sum(dim=0)[None, ...] + return hidden_states, router_scores + + +class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = torch.cat([attn_weights, module.sink], dim=-1) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def openai_flex_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + sink = module.sink + + def attention_sink(score, b, h, q_idx, kv_idx): + score = torch.cat([score, sink], dim=-1) + return score + + return flex_attention_forward( + module, + query, + key, + value, + attention_mask, + scaling=scaling, + dropout=dropout, + attention_sink=attention_sink, + score_mod=attention_sink, + **kwargs, + ) + + +ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) + + +class OpenaiAttention(LlamaAttention): + def __init__(self, config: OpenaiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) + +class OpenaiDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OpenaiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) + self.mlp = OpenaiMLP(config) + self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if kwargs.get("output_router_logits", False): + outputs += (router_logits,) + return outputs + +class OpenaiPreTrainedModel(LlamaPreTrainedModel): + config_class = OpenaiConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OpenaiDecoderLayer"] + + +class OpenaiModel(MixtralModel, OpenaiPreTrainedModel): + _no_split_modules = ["OpenaiDecoderLayer"] + + def __init__(self, config: OpenaiConfig): + super().__init__(config) + self.rope = OpenaiRotaryEmbedding(config) + self.layers = nn.ModuleList([OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class OpenaiForCausalLM(MixtralForCausalLM, OpenaiPreTrainedModel): + def __init__(self, config: OpenaiConfig): + super().__init__(config) + self.model = OpenaiModel(config) + + +__all__ = [ + "OpenaiForCausalLM", + "OpenaiModel", + "OpenaiPreTrainedModel", +] From cf79feda0de0f3ce85adb2278104848d702f600d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:47:32 +0000 Subject: [PATCH 030/342] single model single file --- .../models/openai_moe/modeling_openai_moe.py | 650 ++++++++++++++++-- .../models/openai_moe/modular_openai_moe.py | 4 +- 2 files changed, 582 insertions(+), 72 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 436b62d5d3c4..29f9373fe558 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/openai_moe/modular_openai_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_openai_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. # @@ -17,34 +23,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple -from ...cache_utils import Cache +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + import torch -import torch.utils.checkpoint from torch import nn + from ...activations import ACT2FN -from ...integrations.flex_attention import flex_attention_forward -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...utils import logging -from ..llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, - LlamaRMSNorm, - LlamaRotaryEmbedding, - repeat_kv, -) -from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_openai_moe import OpenaiConfig logger = logging.get_logger(__name__) -class OpenaiRMSNorm(LlamaRMSNorm): - pass +@use_kernel_forward_from_hub("RMSNorm") +class OpenaiRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + OpenaiRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" def swiglu(x, alpha: float = 1.702): @@ -53,6 +74,7 @@ def swiglu(x, alpha: float = 1.702): out_glu = x_glu * torch.sigmoid(alpha * x_glu) return out_glu * (x_linear + 1) + class OpenaiExperts(nn.Module): def __init__(self, config): super().__init__() @@ -82,7 +104,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] swiglu_ = swiglu(gate_up) - next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:,None,:] + next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:, None, :] return next_states @@ -111,8 +133,84 @@ def forward(self, hidden_states): return hidden_states, router_scores -class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): - pass +class OpenaiRotaryEmbedding(nn.Module): + def __init__(self, config: OpenaiConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( @@ -141,47 +239,87 @@ def eager_attention_forward( return attn_output, attn_weights -def openai_flex_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - sink = module.sink - - def attention_sink(score, b, h, q_idx, kv_idx): - score = torch.cat([score, sink], dim=-1) - return score - - return flex_attention_forward( - module, - query, - key, - value, - attention_mask, - scaling=scaling, - dropout=dropout, - attention_sink=attention_sink, - score_mod=attention_sink, - **kwargs, - ) +class OpenaiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def __init__(self, config: OpenaiConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) -ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights -class OpenaiAttention(LlamaAttention): - def __init__(self, config: OpenaiConfig, layer_idx: int): - super().__init__(config, layer_idx) - self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) -class OpenaiDecoderLayer(LlamaDecoderLayer): +class OpenaiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() self.hidden_size = config.hidden_size self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) self.mlp = OpenaiMLP(config) @@ -229,32 +367,404 @@ def forward( if kwargs.get("output_router_logits", False): outputs += (router_logits,) return outputs - -class OpenaiPreTrainedModel(LlamaPreTrainedModel): + + +@auto_docstring +class OpenaiPreTrainedModel(PreTrainedModel): config_class = OpenaiConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OpenaiDecoderLayer"] - - -class OpenaiModel(MixtralModel, OpenaiPreTrainedModel): + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, OpenaiRMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class OpenaiModel(OpenaiPreTrainedModel): _no_split_modules = ["OpenaiDecoderLayer"] def __init__(self, config: OpenaiConfig): super().__init__(config) - self.rope = OpenaiRotaryEmbedding(config) - self.layers = nn.ModuleList([OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = OpenaiRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.rope = OpenaiRotaryEmbedding(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) -class OpenaiForCausalLM(MixtralForCausalLM, OpenaiPreTrainedModel): +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + ... + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class OpenaiForCausalLM(OpenaiPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + def __init__(self, config: OpenaiConfig): super().__init__(config) + # self.model = OpenaiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok self.model = OpenaiModel(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, OpenaiForCausalLM + + >>> model = OpenaiForCausalLM.from_pretrained("mistralai/Openai-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Openai-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + -__all__ = [ - "OpenaiForCausalLM", - "OpenaiModel", - "OpenaiPreTrainedModel", -] +__all__ = ["OpenaiForCausalLM", "OpenaiModel", "OpenaiPreTrainedModel"] diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 77f40546be27..2a229f23bcbb 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -148,10 +148,10 @@ def openai_flex_attention_forward( dropout: float = 0.0, **kwargs, ): - sink = module.sink + sinks = module.sinks def attention_sink(score, b, h, q_idx, kv_idx): - score = torch.cat([score, sink], dim=-1) + score = torch.cat([score, sinks], dim=-1) return score return flex_attention_forward( From 5b62fb0e0e8359ef38a1627292dd9f2b22918903 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 05:50:18 +0000 Subject: [PATCH 031/342] nice --- .../models/openai_moe/configuration_openai_moe.py | 5 +++++ src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++--- src/transformers/models/openai_moe/modular_openai_moe.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index deda53a27111..15382ba9dec9 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -24,6 +24,11 @@ class OpenaiConfig(PretrainedConfig): + r""" + This will yield a configuration to that of the BERT + [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. + + """ model_type = "openai-moe" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `OpenaiModel` diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 29f9373fe558..e039877b74b4 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -231,7 +231,7 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, module.sink], dim=-1) + attn_weights = torch.cat([attn_weights, module.sinks], dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) @@ -642,13 +642,12 @@ class OpenaiForCausalLM(OpenaiPreTrainedModel, GenerationMixin): def __init__(self, config: OpenaiConfig): super().__init__(config) - # self.model = OpenaiModel(config) + self.model = OpenaiModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok - self.model = OpenaiModel(config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 2a229f23bcbb..ea16963c0143 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -130,7 +130,7 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, module.sink], dim=-1) + attn_weights = torch.cat([attn_weights, module.sinks], dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) From 11c431393c7bc20537da0b289e4586d0a2373afa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 06:11:35 +0000 Subject: [PATCH 032/342] small fixes to the sinks! --- .../models/openai_moe/modeling_openai_moe.py | 5 +++-- src/transformers/models/openai_moe/modular_openai_moe.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e039877b74b4..92d60df88dbf 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -225,16 +225,17 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, module.sinks], dim=-1) + attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = torch.matmul(attn_weights[...,:-1], value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index ea16963c0143..89d118eb5de6 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -124,16 +124,17 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, module.sinks], dim=-1) + attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = torch.matmul(attn_weights[...,:-1], value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -148,12 +149,12 @@ def openai_flex_attention_forward( dropout: float = 0.0, **kwargs, ): - sinks = module.sinks + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2],-1) def attention_sink(score, b, h, q_idx, kv_idx): score = torch.cat([score, sinks], dim=-1) return score - + # TODO I need to remove the -1 sinks return flex_attention_forward( module, query, From 425dde039f9a9b15fb8035bddfe651bd1bd38390 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 06:16:39 +0000 Subject: [PATCH 033/342] fix default rms_norm_eps --- src/transformers/models/openai_moe/configuration_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 15382ba9dec9..d4ce6b703fe2 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -62,7 +62,7 @@ def __init__( tie_word_embeddings=False, hidden_act: str = "silu", initializer_range: float = 0.02, - rms_norm_eps: float = 1e-6, + rms_norm_eps: float = 1e-5, pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, From 9a59a418a6bc010d56eb2ca32e86af7c0d4f6791 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 06:41:14 +0000 Subject: [PATCH 034/342] because we are repeating the inputs vs slicing them for compile, I think we need a sigmoid bias --- .../models/openai_moe/modeling_openai_moe.py | 31 +++++++++++-------- .../models/openai_moe/modular_openai_moe.py | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 92d60df88dbf..d872bd17f7f6 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -42,7 +42,7 @@ from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_openai_moe import OpenaiConfig - +import math logger = logging.get_logger(__name__) @@ -68,10 +68,10 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -def swiglu(x, alpha: float = 1.702): +def swiglu(x, sigmoid_bias, alpha: float = 1.702): # Note we add an extra bias of 1 to the linear layer x_glu, x_linear = torch.chunk(x, 2, dim=-1) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) + out_glu = x_glu * torch.sigmoid((alpha * x_glu) + sigmoid_bias.view(x_glu.shape)) return out_glu * (x_linear + 1) @@ -88,7 +88,7 @@ def __init__(self, config): self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, sigmoid_bias: torch.Tensor) -> torch.Tensor: """ This should really not be run on a single machine, as we are reaching compute bound: - the inputs are expected to be "sorted" per expert already. @@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] - swiglu_ = swiglu(gate_up) + swiglu_ = swiglu(gate_up, sigmoid_bias) next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:, None, :] return next_states @@ -121,33 +121,38 @@ def forward(self, hidden_states): # we don't slice weight as its not compile compatible hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1, sorted=True) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_scores = ( torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) routed_in = hidden_states.repeat(self.num_local_experts, 1) - routed_out = self.experts(routed_in) + sigmoid_bias = ( + torch.full_like(routed_in, float("-inf")).scatter_(1, router_indices, router_top_value) + ).contiguous() + + routed_out = self.experts(routed_in, sigmoid_bias) routed_out = routed_out * router_scores.reshape(self.num_local_experts, -1, 1) hidden_states = routed_out.sum(dim=0)[None, ...] return hidden_states, router_scores class OpenaiRotaryEmbedding(nn.Module): + rope_type = "default" def __init__(self, config: OpenaiConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) else: - self.rope_type = "default" + self.attention_scaling = 1.0 + inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.config = config self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -249,7 +254,7 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 + self.scaling = 1 / math.sqrt(self.head_dim) self.attention_dropout = config.attention_dropout self.is_causal = True diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 89d118eb5de6..4da11e3dcd19 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -96,7 +96,7 @@ def forward(self, hidden_states): # we don't slice weight as its not compile compatible hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1, sorted=True) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_scores = ( torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) From 2195b668e482df327d0fd02e480320455b30c0a2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 09:44:33 +0000 Subject: [PATCH 035/342] sink is in the keys! --- .../openai_moe/configuration_openai_moe.py | 2 + .../convert_openai_weights_to_hf.py | 26 +++--- .../models/openai_moe/modeling_openai_moe.py | 79 +++++++++---------- 3 files changed, 51 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index d4ce6b703fe2..3ac771b2c6ef 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -71,6 +71,7 @@ def __init__( num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, output_router_logits=False, + use_cache=True, **kwargs, ): self.vocab_size = vocab_size @@ -104,6 +105,7 @@ def __init__( self.max_position_embeddings = 8192 self.router_aux_loss_coef = router_aux_loss_coef self.output_router_logits = output_router_logits + self.use_cache = use_cache super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 9fcec5f6c085..f4fc63d0df2c 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -113,13 +113,13 @@ def write_model( q_key = re.sub(r"qkv_proj", "q_proj", new_key) k_key = re.sub(r"qkv_proj", "k_proj", new_key) v_key = re.sub(r"qkv_proj", "v_proj", new_key) - state_dict[q_key] = q - state_dict[k_key] = k - state_dict[v_key] = v - elif re.search("gate_up_proj", new_key) and "bias" not in new_key: - state_dict[new_key] = final_[key].permute(0,2,1) + state_dict[q_key] = q.contiguous().to(torch.bfloat16) + state_dict[k_key] = k.contiguous().to(torch.bfloat16) + state_dict[v_key] = v.contiguous().to(torch.bfloat16) + elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key: + state_dict[new_key] = final_[key].permute(0,2,1).contiguous() # einsum in orignal, I use bmm else: - state_dict[new_key] = final_[key].to(torch.bfloat16) # TODO slow, let's rmeove + state_dict[new_key] = final_[key].contiguous().to(torch.bfloat16) # TODO slow, let's rmeove del final_ gc.collect() @@ -223,7 +223,7 @@ def __init__( special_tokens = tokenizer._special_tokens.keys() self.additional_special_tokens = {'<|endoftext|>': 199999, '<|endofprompt|>': 200018} - for k in range(200000, 200018): + for k in range(199999, 200018): self.additional_special_tokens[f"<|reserved_{k}|>"] = k sorted_list = sorted(self.additional_special_tokens.items(), key=lambda x: x[1]) self.additional_special_tokens = [k[0] for k in sorted_list] @@ -315,12 +315,12 @@ def main(): help="Whether the model is an instruct model", ) args = parser.parse_args() - # write_model( - # model_path=args.output_dir, - # input_base_path=args.input_dir, - # safe_serialization=args.safe_serialization, - # instruct=args.instruct, - # ) + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + instruct=args.instruct, + ) write_tokenizer( tokenizer_path="o200k_base", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index d872bd17f7f6..959271f82bad 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -121,6 +121,7 @@ def forward(self, hidden_states): # we don't slice weight as its not compile compatible hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) + return alternative(self, hidden_states, router_logits) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_scores = ( torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) @@ -128,27 +129,45 @@ def forward(self, hidden_states): routed_in = hidden_states.repeat(self.num_local_experts, 1) sigmoid_bias = ( - torch.full_like(routed_in, float("-inf")).scatter_(1, router_indices, router_top_value) + torch.full_like(routed_in, float("-inf")).scatter_(1, router_indices, torch.full_like(routed_in, 0.0)) ).contiguous() routed_out = self.experts(routed_in, sigmoid_bias) - routed_out = routed_out * router_scores.reshape(self.num_local_experts, -1, 1) - hidden_states = routed_out.sum(dim=0)[None, ...] - return hidden_states, router_scores + routed_out = routed_out * router_scores.view(self.num_local_experts, -1, 1) + output_states = routed_out.sum(dim=0)[None, ...] + return output_states, router_scores + + + +def swiglu2(x, alpha: float = 1.702): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + 1) + + +def alternative(self, hidden_states, router_logits): + experts = torch.topk(router_logits, k=4, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + mlp1_weight = self.experts.gate_up_proj[expert_indices, ...] + mlp1_bias = self.experts.gate_up_proj_bias[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight.permute(0,1,3,2), hidden_states) + mlp1_bias + t = swiglu2(t) + mlp2_weight = self.experts.down_proj[expert_indices, ...] + mlp2_bias = self.experts.down_proj_bias[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight.permute(0,1,3,2), t) + t += mlp2_bias + t = torch.einsum("bec,be->bc", t, expert_weights) + return t, router_logits class OpenaiRotaryEmbedding(nn.Module): rope_type = "default" def __init__(self, config: OpenaiConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - else: - self.attention_scaling = 1.0 - inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) + self.attention_scaling = 1.0 + inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -180,25 +199,6 @@ def rotate_half(x): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) @@ -230,7 +230,7 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) # TODO make sure the sink is like a new token attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: @@ -238,9 +238,9 @@ def eager_attention_forward( attn_weights = attn_weights + causal_mask attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[...,:-1], value_states) + attn_output = torch.matmul(attn_weights[...,:-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -297,15 +297,8 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, From 97e3706b5fcdc0c63bbd1c8ffbab9fbc5d9a0782 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 12:57:32 +0000 Subject: [PATCH 036/342] forgot about 1 softmax! --- .../models/openai_moe/modeling_openai_moe.py | 82 ++++++++----------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 959271f82bad..375cc6240f23 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -62,19 +62,12 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return (self.weight * hidden_states).to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -def swiglu(x, sigmoid_bias, alpha: float = 1.702): - # Note we add an extra bias of 1 to the linear layer - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - out_glu = x_glu * torch.sigmoid((alpha * x_glu) + sigmoid_bias.view(x_glu.shape)) - return out_glu * (x_linear + 1) - - class OpenaiExperts(nn.Module): def __init__(self, config): super().__init__() @@ -86,9 +79,10 @@ def __init__(self, config): self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) - self.act_fn = ACT2FN[config.hidden_act] + self.act_fn = torch.nn.Sigmoid() + self.alpha = 1.702 - def forward(self, hidden_states: torch.Tensor, sigmoid_bias: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ This should really not be run on a single machine, as we are reaching compute bound: - the inputs are expected to be "sorted" per expert already. @@ -102,12 +96,14 @@ def forward(self, hidden_states: torch.Tensor, sigmoid_bias: torch.Tensor) -> to torch.Tensor """ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] - swiglu_ = swiglu(gate_up, sigmoid_bias) - next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:, None, :] + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[...,None, :] + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm(((up + 1) * self.act_fn(gate * self.alpha)), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = next_states.view(-1, self.hidden_size) return next_states + class OpenaiMLP(nn.Module): def __init__(self, config): super().__init__() @@ -121,46 +117,23 @@ def forward(self, hidden_states): # we don't slice weight as its not compile compatible hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) - return alternative(self, hidden_states, router_logits) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) router_scores = ( torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) - routed_in = hidden_states.repeat(self.num_local_experts, 1) - sigmoid_bias = ( - torch.full_like(routed_in, float("-inf")).scatter_(1, router_indices, torch.full_like(routed_in, 0.0)) - ).contiguous() - - routed_out = self.experts(routed_in, sigmoid_bias) - routed_out = routed_out * router_scores.view(self.num_local_experts, -1, 1) + routed_out = self.experts(routed_in) + routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] output_states = routed_out.sum(dim=0)[None, ...] return output_states, router_scores +"""" - - -def swiglu2(x, alpha: float = 1.702): - # Note we add an extra bias of 1 to the linear layer - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) - return out_glu * (x_linear + 1) - - -def alternative(self, hidden_states, router_logits): - experts = torch.topk(router_logits, k=4, dim=-1, sorted=True) - expert_weights = torch.nn.functional.softmax(experts.values, dim=1) - expert_indices = experts.indices - mlp1_weight = self.experts.gate_up_proj[expert_indices, ...] - mlp1_bias = self.experts.gate_up_proj_bias[expert_indices, ...] - t = torch.einsum("beck,bk->bec", mlp1_weight.permute(0,1,3,2), hidden_states) + mlp1_bias - t = swiglu2(t) - mlp2_weight = self.experts.down_proj[expert_indices, ...] - mlp2_bias = self.experts.down_proj_bias[expert_indices, ...] - t = torch.einsum("beck,bek->bec", mlp2_weight.permute(0,1,3,2), t) - t += mlp2_bias - t = torch.einsum("bec,be->bc", t, expert_weights) - return t, router_logits - +tensor([[ 7.9688, 0.8750, -0.3535, ..., 3.9219, 9.1875, -2.8906], + [ 2.2500, -15.9375, -2.7500, ..., 18.5000, -1.5312, -18.1250], + [-13.5000, -6.1562, 13.5625, ..., -22.1250, -30.3750, 23.6250], + [ 6.0000, 2.5469, 3.6094, ..., -7.9375, 4.2500, -7.7812]], +""" class OpenaiRotaryEmbedding(nn.Module): rope_type = "default" @@ -185,6 +158,7 @@ def forward(self, x, position_ids): with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) + emb = freqs cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -198,11 +172,23 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + # cos = cos.unsqueeze(-2).to(x.dtype) + # sin = sin.unsqueeze(-2).to(x.dtype) + x1, x2 = torch.chunk(x, 2, dim=-1) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + return torch.cat((o1, o2), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = _apply_rotary_emb(q, cos, sin) + k_embed = _apply_rotary_emb(k, cos, sin) return q_embed, k_embed @@ -231,7 +217,7 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) # TODO make sure the sink is like a new token - + sinks.zero_() attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] From 8794bb1c29aefec7c5906768e0ff12d459361ee2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 13:37:00 +0000 Subject: [PATCH 037/342] fix glu --- .../models/openai_moe/modeling_openai_moe.py | 49 ++++++------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 375cc6240f23..4250914beeca 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -98,7 +98,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[...,None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - next_states = torch.bmm(((up + 1) * self.act_fn(gate * self.alpha)), self.down_proj) + self.down_proj_bias[..., None, :] + glu = gate * self.act_fn(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] next_states = next_states.view(-1, self.hidden_size) return next_states @@ -127,13 +128,7 @@ def forward(self, hidden_states): routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] output_states = routed_out.sum(dim=0)[None, ...] return output_states, router_scores -"""" -tensor([[ 7.9688, 0.8750, -0.3535, ..., 3.9219, 9.1875, -2.8906], - [ 2.2500, -15.9375, -2.7500, ..., 18.5000, -1.5312, -18.1250], - [-13.5000, -6.1562, 13.5625, ..., -22.1250, -30.3750, 23.6250], - [ 6.0000, 2.5469, 3.6094, ..., -7.9375, 4.2500, -7.7812]], -""" class OpenaiRotaryEmbedding(nn.Module): rope_type = "default" @@ -305,6 +300,7 @@ def forward( class OpenaiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__() + self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) self.mlp = OpenaiMLP(config) @@ -483,35 +479,20 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if output_attentions: all_self_attns += (layer_outputs[1],) From a966054d19294f2c79d762f521f8c112d1a05143 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 14:44:46 +0000 Subject: [PATCH 038/342] small nits --- .../convert_openai_weights_to_hf.py | 5 +++- .../models/openai_moe/modeling_openai_moe.py | 30 +++++++------------ 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index f4fc63d0df2c..d9bdafe9f19a 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -119,7 +119,10 @@ def write_model( elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key: state_dict[new_key] = final_[key].permute(0,2,1).contiguous() # einsum in orignal, I use bmm else: - state_dict[new_key] = final_[key].contiguous().to(torch.bfloat16) # TODO slow, let's rmeove + weight = final_[key] + if not re.search("norm", new_key): + weight = weight.to(torch.bfloat16) # norms are the only ones in float32 + state_dict[new_key] = weight del final_ gc.collect() diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 4250914beeca..43612fe07ddd 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -121,7 +121,7 @@ def forward(self, hidden_states): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) router_scores = ( - torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) + torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) routed_in = hidden_states.repeat(self.num_local_experts, 1) routed_out = self.experts(routed_in) @@ -146,38 +146,28 @@ def __init__(self, config: OpenaiConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): + self.inv_freq = 1.0 / (self.config.rope_theta ** ( torch.arange(0, self.config.head_dim, 2, dtype=torch.float32, device=x.device)/ self.config.head_dim)) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) emb = freqs cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - + + return cos.to(x.dtype), sin.to(x.dtype) def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: - # cos = cos.unsqueeze(-2).to(x.dtype) - # sin = sin.unsqueeze(-2).to(x.dtype) - x1, x2 = torch.chunk(x, 2, dim=-1) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - return torch.cat((o1, o2), dim=-1) + first_half, second_half = torch.chunk(x, 2, dim=-1) + first_ = first_half * cos - second_half * sin + second_ = second_half * cos + first_half * sin + return torch.cat((first_, second_), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) @@ -186,7 +176,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = _apply_rotary_emb(k, cos, sin) return q_embed, k_embed - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -219,7 +208,7 @@ def eager_attention_forward( attn_weights = attn_weights + causal_mask attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) + attn_weights = torch.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights[...,:-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() @@ -356,6 +345,7 @@ class OpenaiPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OpenaiDecoderLayer"] + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True From 7ff7537dc56e3eb5787d0e53320f621671f3048d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 15:33:39 +0000 Subject: [PATCH 039/342] make tp work --- .../integrations/tensor_parallel.py | 47 +++++++++++++------ .../openai_moe/configuration_openai_moe.py | 14 +++--- .../models/openai_moe/modeling_openai_moe.py | 2 +- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index a9f8940e72e1..9ebbeea99a9d 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -198,8 +198,10 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): slice_dtype = slice_.get_dtype() # Handle F8_E4M3 dtype by converting to float16 before slicing # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn' - if slice_dtype == "F8_E4M3": + casted = False + if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2": slice_ = slice_[...].to(torch.float16) + casted = True if dim == 0: tensor = slice_[tensors_slices, ...] @@ -209,7 +211,11 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): tensor = slice_[..., tensors_slices] else: raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") - return tensor.to(str_to_torch_dtype[slice_dtype]) + + if casted: + return tensor + else: + return tensor.to(str_to_torch_dtype[slice_dtype]) def repack_weights( @@ -408,6 +414,15 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # TODO: figure out dynamo support for instance method and switch this to instance method return outputs + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + param = param[...].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + param = param / device_mesh.size() # TODO should be optionable + return param + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: distribute_module( module, @@ -823,7 +838,7 @@ def __init__(self): module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" # 2. We add hooks to the parent module if needed - if "." in layer_name: + if "." in layer_name and current_module_plan != "replicate": parent_layer_name = layer_name.rsplit(".", 1)[0] generic_name = re.sub(r"\d+", "*", parent_layer_name) # The module itself needs hooks @@ -855,7 +870,7 @@ def shard_and_distribute_module( current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan) if current_module_plan is None: - current_module_plan = "replicate" + # current_module_plan = "replicate" -> mega breaking for now as it seems OpenaiAttention gets some hoooks if dist.get_rank() == 0: logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.") else: @@ -867,21 +882,23 @@ def shard_and_distribute_module( if not getattr(module_to_tp, "_is_hooked", False): add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) module_to_tp._is_hooked = True - - try: - tp_layer = ALL_PARALLEL_STYLES[current_module_plan] - param = tp_layer.partition_tensor( - param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh - ) - except NotImplementedError as e: - print( - f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" - ) + if current_module_plan is not None: + try: + tp_layer = ALL_PARALLEL_STYLES[current_module_plan] + param = tp_layer.partition_tensor( + param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh + ) + except NotImplementedError as e: + print( + f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" + ) + else: + param = param[:].to(param_casting_dtype) # SUPER IMPORTANT we have to use setattr # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): - param = torch.nn.Parameter(param, requires_grad=param.is_floating_point()) + param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) setattr(module_to_tp, param_type, param) # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) return param diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 3ac771b2c6ef..1ed0ae935506 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -30,16 +30,14 @@ class OpenaiConfig(PretrainedConfig): """ model_type = "openai-moe" - keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `OpenaiModel` + # a bit special, but this seems to work alright base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_up_proj": "local_packed_rowwise", - "layers.*.mlp.down_proj": "local_colwise", - "layers.*.mlp": "local", + "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise", + "layers.*.mlp.experts.down_proj": "local_colwise", + "layers.*.mlp.experts.down_proj_bias": "local", + "layers.*.mlp.experts": "gather", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 43612fe07ddd..5f9731ca1c50 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -99,7 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[...,None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * self.act_fn(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] // 4 next_states = next_states.view(-1, self.hidden_size) return next_states From 56305ef685359145bfd35dbe7857942d76af32af Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 16:21:23 +0000 Subject: [PATCH 040/342] fix tp --- .../integrations/tensor_parallel.py | 28 +++++++++++-------- .../models/openai_moe/modeling_openai_moe.py | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 9ebbeea99a9d..f2c72422a5db 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -807,7 +807,7 @@ def replace_state_dict_local_with_dtensor( return state_dict -def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh): +def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None): """ Add hooks to the module holding the layer. Meaning: ``` @@ -839,10 +839,16 @@ def __init__(self): # 2. We add hooks to the parent module if needed if "." in layer_name and current_module_plan != "replicate": - parent_layer_name = layer_name.rsplit(".", 1)[0] + parent_layer_name = parameter_name.rsplit(".", 1)[0] generic_name = re.sub(r"\d+", "*", parent_layer_name) # The module itself needs hooks - if module_plan := tp_plan.get(generic_name, False): + if module_plan := tp_plan.get(generic_name, False) and not module._hf_tp_plan: + tp_layer = ALL_PARALLEL_STYLES[module_plan] + module_to_tp_ = model.get_submodule(parent_layer_name) + tp_layer.prepare_module_tp(module_to_tp_, device_mesh) + module_to_tp_._hf_tp_plan = current_module_plan + module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" + elif module_plan := tp_plan.get(re.sub(r"\d+", "*", layer_name.rsplit(".", 1)[0]), False)and not module._hf_tp_plan: tp_layer = ALL_PARALLEL_STYLES[module_plan] module_to_tp_ = model.get_submodule(parent_layer_name) tp_layer.prepare_module_tp(module_to_tp_, device_mesh) @@ -867,24 +873,24 @@ def shard_and_distribute_module( module_to_tp = model.get_submodule(param_name) rank = int(rank) - current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan) + current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan) - if current_module_plan is None: - # current_module_plan = "replicate" -> mega breaking for now as it seems OpenaiAttention gets some hoooks + if current_shard_plan is None: if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.") + logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") else: if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}") + logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") # Add hooks to the module if not done yet # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) if not getattr(module_to_tp, "_is_hooked", False): - add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) + add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_shard_plan, device_mesh, parameter_name) module_to_tp._is_hooked = True - if current_module_plan is not None: + + if current_shard_plan is not None: try: - tp_layer = ALL_PARALLEL_STYLES[current_module_plan] + tp_layer = ALL_PARALLEL_STYLES[current_shard_plan] param = tp_layer.partition_tensor( param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh ) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 5f9731ca1c50..43612fe07ddd 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -99,7 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[...,None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * self.act_fn(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] // 4 + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] next_states = next_states.view(-1, self.hidden_size) return next_states From e5db2f86dc5f64c6d89bd9103a048be57139e4df Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 20:35:16 +0000 Subject: [PATCH 041/342] fix batch size and update exampel? --- examples/3D_parallel.py | 4 ++-- src/transformers/modeling_utils.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/3D_parallel.py b/examples/3D_parallel.py index 64a0fd8fdc7c..958336eefd3e 100644 --- a/examples/3D_parallel.py +++ b/examples/3D_parallel.py @@ -137,7 +137,7 @@ def main(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") - + print(f"TP MESH: {tp_mesh}") model = AutoModelForCausalLM.from_pretrained( model_name, device_mesh=tp_mesh if dist.is_initialized() else None, @@ -200,7 +200,7 @@ def create_packed_sequences(examples): batched=True, remove_columns=tokenized_dataset.column_names, batch_size=1000, # Process in batches for efficiency - num_proc=60, + # num_proc=60, ) logger.info(f"Dataset packed. New size: {len(packed_dataset)}") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2420c585fcf9..7cd641971b7f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4303,7 +4303,7 @@ def from_pretrained( tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) else: # TODO: make device_mesh support multiple dimensions - if device_mesh.ndim == 1: + if device_mesh.ndim > 1: raise ValueError("device_mesh must be 1 dimensional and will be used for TP") device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 43612fe07ddd..1ea5079eaad2 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -116,6 +116,7 @@ def __init__(self, config): def forward(self, hidden_states): # we don't slice weight as its not compile compatible + batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) @@ -126,7 +127,7 @@ def forward(self, hidden_states): routed_in = hidden_states.repeat(self.num_local_experts, 1) routed_out = self.experts(routed_in) routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] - output_states = routed_out.sum(dim=0)[None, ...] + output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) return output_states, router_scores @@ -200,13 +201,14 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) # TODO make sure the sink is like a new token - sinks.zero_() + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO make sure the sink is like a new token + # sinks.zero_() attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - + print("attn_weights", attn_weights.shape, attn_weights.dtype) + print("sinks", sinks.shape, sinks.dtype) attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) @@ -452,7 +454,6 @@ def forward( attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds From 7f863a1dae786c585ed505ea8f5a6b5a85a513df Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 20:38:26 +0000 Subject: [PATCH 042/342] oups prints got in there --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 1ea5079eaad2..0b8782d088d2 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -207,8 +207,6 @@ def eager_attention_forward( if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - print("attn_weights", attn_weights.shape, attn_weights.dtype) - print("sinks", sinks.shape, sinks.dtype) attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) From 51013813f34a4b82573117bc1a4e87a0212b2035 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 20:52:50 +0000 Subject: [PATCH 043/342] fix tp + more TP --- .../integrations/tensor_parallel.py | 17 +++++------------ .../openai_moe/configuration_openai_moe.py | 5 +++++ .../models/openai_moe/modeling_openai_moe.py | 1 - 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 11110dc4d5cc..f2116179396c 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -834,26 +834,19 @@ def __init__(self): print( f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}" ) - module._hf_tp_plan = current_module_plan - module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" # 2. We add hooks to the parent module if needed if "." in layer_name and current_module_plan != "replicate": parent_layer_name = parameter_name.rsplit(".", 1)[0] generic_name = re.sub(r"\d+", "*", parent_layer_name) # The module itself needs hooks - if module_plan := tp_plan.get(generic_name, False) and not module._hf_tp_plan: + if module_plan := tp_plan.get(generic_name, False): tp_layer = ALL_PARALLEL_STYLES[module_plan] module_to_tp_ = model.get_submodule(parent_layer_name) - tp_layer.prepare_module_tp(module_to_tp_, device_mesh) - module_to_tp_._hf_tp_plan = current_module_plan - module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" - elif module_plan := tp_plan.get(re.sub(r"\d+", "*", layer_name.rsplit(".", 1)[0]), False)and not module._hf_tp_plan: - tp_layer = ALL_PARALLEL_STYLES[module_plan] - module_to_tp_ = model.get_submodule(parent_layer_name) - tp_layer.prepare_module_tp(module_to_tp_, device_mesh) - module_to_tp_._hf_tp_plan = current_module_plan - module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" + if not getattr(module_to_tp_, "_hf_tp_plan", False): + tp_layer.prepare_module_tp(module_to_tp_, device_mesh) + module_to_tp_._hf_tp_plan = current_module_plan + module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" def shard_and_distribute_module( diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 1ed0ae935506..368d571fb764 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -33,6 +33,11 @@ class OpenaiConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 0b8782d088d2..e8cab3ca76b8 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -202,7 +202,6 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO make sure the sink is like a new token - # sinks.zero_() attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] From e43411aaa63b73a7ba6dddf93c40da0290d0a323 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 May 2025 21:38:47 +0000 Subject: [PATCH 044/342] push a version that does support training, but worst for inference, not compile compatible --- .../models/openai_moe/modeling_openai_moe.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e8cab3ca76b8..25a36f2168a4 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -82,7 +82,7 @@ def __init__(self, config): self.act_fn = torch.nn.Sigmoid() self.alpha = 1.702 - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_weights=None) -> torch.Tensor: """ This should really not be run on a single machine, as we are reaching compute bound: - the inputs are expected to be "sorted" per expert already. @@ -95,11 +95,33 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor """ - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[...,None, :] + # hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + final_hidden_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist() + for expert_idx in expert_hitted: + idx, top_x = torch.where(expert_mask[expert_idx]) # idx: top-1/top-2 indicator, top_x: token indices + current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] # (num_tokens, 2 * interm_dim) + gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + glu = gate * self.act_fn(gate * self.alpha) # (num_tokens, interm_dim) + gated_output = (up + 1) * glu # (num_tokens, interm_dim) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] # (num_tokens, hidden_dim) + weighted_output = out * routing_weights[top_x, idx].unsqueeze(-1) # (num_tokens, hidden_dim) + final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)) + return final_hidden_states + + + gate_up = torch.bmm(hidden_states, self.gate_up_proj[router_indices, ...]) + self.gate_up_proj_bias[router_indices, ...,None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * self.act_fn(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = torch.bmm(((up + 1) * glu), self.down_proj[router_indices, ...]) + self.down_proj_bias[router_indices, ..., None, :] next_states = next_states.view(-1, self.hidden_size) return next_states @@ -124,11 +146,12 @@ def forward(self, hidden_states): router_scores = ( torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) - routed_in = hidden_states.repeat(self.num_local_experts, 1) - routed_out = self.experts(routed_in) - routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] - output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) - return output_states, router_scores + # routed_in = hidden_states.repeat(self.num_local_experts, 1) + routed_out = self.experts(hidden_states, router_indices, router_top_value) + routed_out = routed_out.view(batch_size, -1, self.hidden_dim) + # routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] + # output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) + return routed_out, router_scores class OpenaiRotaryEmbedding(nn.Module): @@ -266,8 +289,8 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - # if self.config._attn_implementation != "eager": - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, From 079840adb67c46e140b85d36d9b4de51de419880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 30 May 2025 06:22:24 +0000 Subject: [PATCH 045/342] Replace activation module with function and fix down_proj_bias shape --- src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 25a36f2168a4..23a3383f3af6 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -78,8 +78,7 @@ def __init__(self, config): self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) - self.act_fn = torch.nn.Sigmoid() + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_weights=None) -> torch.Tensor: @@ -110,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_we current_state = hidden_states[top_x] # (num_tokens, hidden_dim) gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] # (num_tokens, 2 * interm_dim) gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) - glu = gate * self.act_fn(gate * self.alpha) # (num_tokens, interm_dim) + glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) gated_output = (up + 1) * glu # (num_tokens, interm_dim) out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] # (num_tokens, hidden_dim) weighted_output = out * routing_weights[top_x, idx].unsqueeze(-1) # (num_tokens, hidden_dim) From 5ce8b563721c4ba37f83a1e61873bcfa869673bd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 2 Jun 2025 08:33:32 +0000 Subject: [PATCH 046/342] updates: trainnig and inference paths, no synch tolist for compile --- .../integrations/flex_attention.py | 2 +- src/transformers/masking_utils.py | 21 ++-- .../openai_moe/configuration_openai_moe.py | 10 +- .../models/openai_moe/modeling_openai_moe.py | 119 +++++++++++------- .../models/openai_moe/modular_openai_moe.py | 1 + 5 files changed, 92 insertions(+), 61 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 7281765c0086..e69f7a4c6672 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -292,7 +292,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): training=module.training, ) # lse is returned in float32 - attention_weights = attention_weights.to(value.dtype) + attention_weights = attention_weights.to(value.dtype)[:, :, :, : key.shape[-2]] # potential sink attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attention_weights diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index cb502206d780..5c402dd65d72 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -590,7 +590,6 @@ def _preprocess_mask_arguments( attention_mask: Optional[Union[torch.Tensor, BlockMask]], cache_position: torch.Tensor, past_key_values: Optional[Cache], - layer_idx: Optional[int], ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]: """ Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the @@ -641,6 +640,10 @@ def _preprocess_mask_arguments( # If using a cache, it can give all informations about mask sizes based on seen tokens if past_key_values is not None: + if hasattr(past_key_values, "is_sliding") and isinstance(past_key_values.is_sliding, list): + layer_idx = past_key_values.is_sliding.index(False) + else: + layer_idx = 0 kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx) # Otherwise, the sizes are simply the input sizes else: @@ -684,13 +687,8 @@ def create_causal_mask( useful to easily overlay another mask on top of the causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the full layers - try: - layer_idx = past_key_values.is_sliding.index(False) - except (ValueError, AttributeError): - layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + config, input_embeds, attention_mask, cache_position, past_key_values ) if early_exit: return attention_mask @@ -729,7 +727,7 @@ def create_causal_mask( ) return causal_mask - +@torch.compiler.disable(recursive=True) def create_sliding_window_causal_mask( config: PretrainedConfig, input_embeds: torch.Tensor, @@ -766,13 +764,10 @@ def create_sliding_window_causal_mask( useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - try: - layer_idx = past_key_values.is_sliding.index(True) - except (ValueError, AttributeError): - layer_idx = 0 + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + config, input_embeds, attention_mask, cache_position, past_key_values ) if early_exit: return attention_mask diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 368d571fb764..4b21c8dc9909 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -33,11 +33,11 @@ class OpenaiConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "local_rowwise", + # "layers.*.self_attn.q_proj": "colwise", + # "layers.*.self_attn.k_proj": "colwise", + # "layers.*.self_attn.v_proj": "colwise", + # "layers.*.self_attn.o_proj": "rowwise", + # "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 23a3383f3af6..c8c0a474c224 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -33,7 +33,8 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...integrations.flex_attention import flex_attention_forward +from ...masking_utils import create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -83,9 +84,10 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_weights=None) -> torch.Tensor: """ - This should really not be run on a single machine, as we are reaching compute bound: - - the inputs are expected to be "sorted" per expert already. - - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + When training is is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. Args: hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) @@ -94,34 +96,32 @@ def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_we Returns: torch.Tensor """ - # hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - final_hidden_states = torch.zeros_like( - hidden_states, dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) - - expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist() - for expert_idx in expert_hitted: - idx, top_x = torch.where(expert_mask[expert_idx]) # idx: top-1/top-2 indicator, top_x: token indices - current_state = hidden_states[top_x] # (num_tokens, hidden_dim) - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] # (num_tokens, 2 * interm_dim) - gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) - glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) - gated_output = (up + 1) * glu # (num_tokens, interm_dim) - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx].unsqueeze(-1) # (num_tokens, hidden_dim) - final_hidden_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)) - return final_hidden_states - - - gate_up = torch.bmm(hidden_states, self.gate_up_proj[router_indices, ...]) + self.gate_up_proj_bias[router_indices, ...,None, :] - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - glu = gate * self.act_fn(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj[router_indices, ...]) + self.down_proj_bias[router_indices, ..., None, :] - next_states = next_states.view(-1, self.hidden_size) + if self.training: + next_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + with torch.no_grad(): + idx, top_x = torch.where(expert_mask[expert_idx][0]) # idx: top-1/top-2 indicator, top_x: token indices + current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] # (num_tokens, 2 * interm_dim) + gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) + gated_output = (up + 1) * glu # (num_tokens, interm_dim) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] # (num_tokens, hidden_dim) + weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = next_states.view(-1, self.hidden_size) return next_states @@ -145,12 +145,13 @@ def forward(self, hidden_states): router_scores = ( torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) - # routed_in = hidden_states.repeat(self.num_local_experts, 1) routed_out = self.experts(hidden_states, router_indices, router_top_value) - routed_out = routed_out.view(batch_size, -1, self.hidden_dim) - # routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] - # output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) - return routed_out, router_scores + if self.training: + output_states = routed_out.view(batch_size, -1, self.hidden_dim) + else: + routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] + output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) + return output_states, router_scores class OpenaiRotaryEmbedding(nn.Module): @@ -158,7 +159,8 @@ class OpenaiRotaryEmbedding(nn.Module): def __init__(self, config: OpenaiConfig, device=None): super().__init__() self.attention_scaling = 1.0 - inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) + with torch.device("cpu"): + inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -169,7 +171,6 @@ def __init__(self, config: OpenaiConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - self.inv_freq = 1.0 / (self.config.rope_theta ** ( torch.arange(0, self.config.head_dim, 2, dtype=torch.float32, device=x.device)/ self.config.head_dim)) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() @@ -236,6 +237,41 @@ def eager_attention_forward( return attn_output, attn_weights +def openai_flex_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2],-1) + + def attention_sink(score, b, h, q_idx, kv_idx): + score = torch.cat([score, sinks], dim=-1) + return score + + # TODO I need to remove the -1 sinks + return flex_attention_forward( + module, + query, + key, + value, + attention_mask, + scaling=scaling, + dropout=dropout, + attention_sink=attention_sink, + score_mod=attention_sink, + **kwargs, + ) + + +ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) + + + class OpenaiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -261,7 +297,7 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) + self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) def forward( self, @@ -466,8 +502,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask - causal_mask = mask_function( + causal_mask = create_sliding_window_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 4da11e3dcd19..b12c0c2a5cc9 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -154,6 +154,7 @@ def openai_flex_attention_forward( def attention_sink(score, b, h, q_idx, kv_idx): score = torch.cat([score, sinks], dim=-1) return score + # TODO I need to remove the -1 sinks return flex_attention_forward( module, From c7736764cc355dd2fee85524ed13a96eca4f3eb1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 2 Jun 2025 08:57:05 +0000 Subject: [PATCH 047/342] update conversion for correct tokenizer --- .../convert_openai_weights_to_hf.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index d9bdafe9f19a..90c3a0636406 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -84,7 +84,7 @@ def write_model( ): os.makedirs(model_path, exist_ok=True) bos_token_id = 128000 - eos_token_id = [128001, 128008, 128009] if instruct else 128001 + eos_token_id = 199999 if not instruct else [199999, 200018] pad_token_id = 128004 config = OpenaiConfig() @@ -186,7 +186,7 @@ def bytes_to_unicode(): class OpenaiConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): tokenizer = tiktoken.get_encoding(tiktoken_url) - + self.pattern = tokenizer._pat_str bpe_ranks = tokenizer._mergeable_ranks byte_encoder = bytes_to_unicode() @@ -213,18 +213,17 @@ def token_bytes_to_string(b): def __init__( self, vocab_file, - special_tokens: List[str], - pattern: str, model_max_length: int, chat_template: Optional[str] = None, **kwargs, ): - super().__init__(vocab_file, pattern=pattern) + super().__init__(vocab_file, pattern=None) # TODO 1st donwload the vocabfile!!! tokenizer = tiktoken.get_encoding(vocab_file) - special_tokens = tokenizer._special_tokens.keys() - + self.additional_special_tokens = {} + # 199998 is not defined either + self.additional_special_tokens["<|reserved_199998|>"] = 199998 self.additional_special_tokens = {'<|endoftext|>': 199999, '<|endofprompt|>': 200018} for k in range(199999, 200018): self.additional_special_tokens[f"<|reserved_{k}|>"] = k @@ -242,9 +241,6 @@ def __init__( def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): - pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605 - special_tokens = [] - # Chat template chat_template = ( "{% for message in messages %}" @@ -272,13 +268,8 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): converter = OpenaiConverter( vocab_file=tokenizer_path, - pattern=pattern, - special_tokens=special_tokens, model_max_length=None, chat_template=chat_template if instruct else None, - bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>", - pad_token="<|finetune_right_pad_id|>", ) tokenizer = converter.tokenizer tokenizer.save_pretrained(save_dir) @@ -318,12 +309,12 @@ def main(): help="Whether the model is an instruct model", ) args = parser.parse_args() - write_model( - model_path=args.output_dir, - input_base_path=args.input_dir, - safe_serialization=args.safe_serialization, - instruct=args.instruct, - ) + # write_model( + # model_path=args.output_dir, + # input_base_path=args.input_dir, + # safe_serialization=args.safe_serialization, + # instruct=args.instruct, + # ) write_tokenizer( tokenizer_path="o200k_base", From 3269cdd6392b6488c3537a49f40ce292915a16a3 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 5 Jun 2025 17:26:41 +0900 Subject: [PATCH 048/342] Add sliding window to modeling --- .../models/openai_moe/modeling_openai_moe.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index c8c0a474c224..24dae925bc98 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -298,6 +298,7 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None def forward( self, @@ -352,6 +353,7 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.mlp = OpenaiMLP(config) self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -502,13 +504,21 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = create_sliding_window_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } hidden_states = inputs_embeds @@ -526,7 +536,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, From bdd3e562fdda8038aef4d549e2b2dad1191aee58 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 5 Jun 2025 17:28:14 +0900 Subject: [PATCH 049/342] Update configuration_openai_moe.py for sliding window --- .../models/openai_moe/configuration_openai_moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 4b21c8dc9909..03c4414ddfe2 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -19,7 +19,7 @@ # limitations under the License. """openai model configuration""" -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation @@ -97,12 +97,18 @@ def __init__( self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - + layer_type_validation(self.layer_types) + self.attention_bias = True self.mlp_bias = False self.max_position_embeddings = 8192 From a75d97be2b31b15e6ac90e46a4cf0ffc8e4191b8 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 5 Jun 2025 13:20:31 +0000 Subject: [PATCH 050/342] fix nits to make it work with sliding window --- .../models/openai_moe/configuration_openai_moe.py | 4 +++- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 03c4414ddfe2..bda4eed795c9 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -75,6 +75,7 @@ def __init__( router_aux_loss_coef: float = 0.9, output_router_logits=False, use_cache=True, + layer_types=None, **kwargs, ): self.vocab_size = vocab_size @@ -97,17 +98,18 @@ def __init__( self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) ] + layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - layer_type_validation(self.layer_types) self.attention_bias = True self.mlp_bias = False diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 24dae925bc98..e1da7b0639cb 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -34,7 +34,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...integrations.flex_attention import flex_attention_forward -from ...masking_utils import create_sliding_window_causal_mask +from ...masking_utils import create_sliding_window_causal_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast From 70e0db01a15e2ede8107f5aab915fff3ffe51fe4 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 6 Jun 2025 16:52:50 +0000 Subject: [PATCH 051/342] tp-training --- .../integrations/tensor_parallel.py | 2 +- src/transformers/modeling_utils.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 2 +- src/transformers/trainer.py | 36 +++++++++++++------ 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index f2116179396c..c751cf385637 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -90,7 +90,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None): device_map = tp_device tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) - return tp_device, device_map, device_mesh + return tp_device, device_map, device_mesh, tp_size def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7cd641971b7f..3f53f9d10e96 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4300,7 +4300,7 @@ def from_pretrained( # `device_map` pointing to the correct device if tp_plan is not None: if device_mesh is None and tp_plan is not None: - tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) + tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size) else: # TODO: make device_mesh support multiple dimensions if device_mesh.ndim > 1: diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e1da7b0639cb..251a627121f3 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -141,7 +141,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = ( torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) ) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c2a44b6ff081..7b5b5589ed70 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2306,7 +2306,9 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + delay_optimizer_creation = ( + is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled + ) # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) @@ -2366,7 +2368,10 @@ def _inner_training_loop( if self.use_apex: model = self.accelerator.prepare(self.model) else: - model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + if delay_optimizer_creation: + self.optimizer = self.accelerator.prepare(self.optimizer) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( @@ -2585,10 +2590,16 @@ def _inner_training_loop( args.max_grad_norm, ) else: - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) + grad_norm_context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + grad_norm_context = implicit_replication + with grad_norm_context(): + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) if ( is_accelerate_available() @@ -2603,7 +2614,12 @@ def _inner_training_loop( self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - self.optimizer.step() + context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + context = implicit_replication + with context(): + self.optimizer.step() self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) @@ -3906,9 +3922,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa # remove the dummy state_dict remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) self.model_wrapped.save_checkpoint(output_dir) - - elif self.args.should_save: - self._save(output_dir) + + # elif self.args.should_save: + self._save(output_dir) # Push to the Hub when `save_model` is called by the user. if self.args.push_to_hub and not _internal_call: From 2731d3dce58c85e8250ed2435cd23b8414c93f2f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Jun 2025 17:21:25 +0200 Subject: [PATCH 052/342] fix converter --- .../openai_moe/convert_openai_weights_to_hf.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 90c3a0636406..38057c066a62 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -43,7 +43,7 @@ # special key, wqkv needs to be split afterwards r"block.(\d+).attn.qkv": r"layers.\1.self_attn.qkv_proj", r"block.(\d+).attn.out": r"layers.\1.self_attn.o_proj", - r"block.(\d+).attn.sdpa.sinks": r"layers.\1.self_attn.sinks", + r"block.(\d+).attn.sinks": r"layers.\1.self_attn.sinks", r"block.(\d+).attn.norm.scale": r"layers.\1.input_layernorm.weight", r"block.(\d+).mlp.mlp1_weight": r"layers.\1.mlp.experts.gate_up_proj", @@ -309,12 +309,12 @@ def main(): help="Whether the model is an instruct model", ) args = parser.parse_args() - # write_model( - # model_path=args.output_dir, - # input_base_path=args.input_dir, - # safe_serialization=args.safe_serialization, - # instruct=args.instruct, - # ) + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + instruct=args.instruct, + ) write_tokenizer( tokenizer_path="o200k_base", From d89b31f8b9d1c21bab162025f05a64ad77a9acb7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 10 Jun 2025 22:17:25 +0200 Subject: [PATCH 053/342] add yarn support --- src/transformers/modeling_rope_utils.py | 13 +++++++++---- .../models/openai_moe/configuration_openai_moe.py | 5 +++-- .../models/openai_moe/modeling_openai_moe.py | 13 +++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index ccb961edea9d..7d36226ac3ea 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -288,10 +288,13 @@ def find_correction_dim(num_rotations, dim, base, max_position_embeddings): """Inverse dimension formula to find the dimension based on the number of rotations""" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = low = math.floor(low) + high = math.ceil(high) return max(low, 0), min(high, dim - 1) def linear_ramp_factor(min, max, dim): @@ -308,7 +311,8 @@ def linear_ramp_factor(min, max, dim): inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) + truncate = config.rope_scaling.get("truncate", True) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate) # Get n-dimensional rotational scaling corrected for extrapolation inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) @@ -512,6 +516,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se "original_max_position_embeddings", "mscale", "mscale_all_dim", + "truncate", } received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index bda4eed795c9..9725daa7fd5a 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -65,11 +65,12 @@ def __init__( tie_word_embeddings=False, hidden_act: str = "silu", initializer_range: float = 0.02, + max_position_embeddings=131072, rms_norm_eps: float = 1e-5, pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, - rope_scaling=None, + rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, attention_dropout: float = 0.0, num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, @@ -113,7 +114,7 @@ def __init__( self.attention_bias = True self.mlp_bias = False - self.max_position_embeddings = 8192 + self.max_position_embeddings = max_position_embeddings self.router_aux_loss_coef = router_aux_loss_coef self.output_router_logits = output_router_logits self.use_cache = use_cache diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e1da7b0639cb..9c6f2d8a6eb8 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -158,13 +158,18 @@ class OpenaiRotaryEmbedding(nn.Module): rope_type = "default" def __init__(self, config: OpenaiConfig, device=None): super().__init__() - self.attention_scaling = 1.0 - with torch.device("cpu"): - inv_freq = 1.0 / (config.rope_theta ** ( torch.arange(0, config.head_dim, 2, dtype=torch.float32)/ config.head_dim)) + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings - self.config = config + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq From 55ba2924505046ee597eb324e4dd799224ce2add Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 10:23:41 +0200 Subject: [PATCH 054/342] align modular to modeling --- .../models/openai_moe/modeling_openai_moe.py | 130 +++----- .../models/openai_moe/modular_openai_moe.py | 287 ++++++++++++++---- 2 files changed, 284 insertions(+), 133 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 9c6f2d8a6eb8..8b611165166f 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -23,18 +23,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn -from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...integrations.flex_attention import flex_attention_forward -from ...masking_utils import create_sliding_window_causal_mask, create_causal_mask +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -43,7 +40,7 @@ from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging from .configuration_openai_moe import OpenaiConfig -import math + logger = logging.get_logger(__name__) @@ -63,7 +60,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return (self.weight * hidden_states).to(input_dtype) + return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -80,12 +77,12 @@ def __init__(self, config): self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) - self.alpha = 1.702 + self.alpha = 1.702 - def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_weights=None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: """ When training is is more efficient to just loop over the experts and compute the output for each expert - as otherwise the memory would explode. + as otherwise the memory would explode. For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. @@ -97,22 +94,28 @@ def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_we torch.Tensor """ if self.training: - next_states = torch.zeros_like( - hidden_states, dtype=hidden_states.dtype, device=hidden_states.device - ) + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute( + 2, 1, 0 + ) expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hitted: with torch.no_grad(): - idx, top_x = torch.where(expert_mask[expert_idx][0]) # idx: top-1/top-2 indicator, top_x: token indices + idx, top_x = torch.where( + expert_mask[expert_idx][0] + ) # idx: top-1/top-2 indicator, top_x: token indices current_state = hidden_states[top_x] # (num_tokens, hidden_dim) - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] # (num_tokens, 2 * interm_dim) + gate_up = ( + current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + ) # (num_tokens, 2 * interm_dim) gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) gated_output = (up + 1) * glu # (num_tokens, interm_dim) - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + out = ( + gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + ) # (num_tokens, hidden_dim) + weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) else: hidden_states = hidden_states.repeat(self.num_experts, 1) @@ -125,7 +128,6 @@ def forward(self, hidden_states: torch.Tensor, router_indices = None, routing_we return next_states - class OpenaiMLP(nn.Module): def __init__(self, config): super().__init__() @@ -142,12 +144,10 @@ def forward(self, hidden_states): router_logits = self.router(hidden_states) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = ( - torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) - ) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) routed_out = self.experts(hidden_states, router_indices, router_top_value) if self.training: - output_states = routed_out.view(batch_size, -1, self.hidden_dim) + output_states = routed_out.view(batch_size, -1, self.hidden_dim) else: routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) @@ -155,7 +155,6 @@ def forward(self, hidden_states): class OpenaiRotaryEmbedding(nn.Module): - rope_type = "default" def __init__(self, config: OpenaiConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" @@ -185,9 +184,22 @@ def forward(self, x, position_ids): emb = freqs cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - + return cos.to(x.dtype), sin.to(x.dtype) + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, @@ -198,6 +210,7 @@ def _apply_rotary_emb( second_ = second_half * cos + first_half * sin return torch.cat((first_, second_), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) @@ -205,17 +218,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = _apply_rotary_emb(k, cos, sin) return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - def eager_attention_forward( module: nn.Module, @@ -229,54 +231,23 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO make sure the sink is like a new token + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + query.shape[0], -1, query.shape[-2], -1 + ) # TODO make sure the sink is like a new token + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[...,:-1], value_states) # ignore the sinks + attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -def openai_flex_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2],-1) - - def attention_sink(score, b, h, q_idx, kv_idx): - score = torch.cat([score, sinks], dim=-1) - return score - - # TODO I need to remove the -1 sinks - return flex_attention_forward( - module, - query, - key, - value, - attention_mask, - scaling=scaling, - dropout=dropout, - attention_sink=attention_sink, - score_mod=attention_sink, - **kwargs, - ) - - -ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) - - - class OpenaiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -286,10 +257,9 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = 1 / math.sqrt(self.head_dim) + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) @@ -302,8 +272,8 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) def forward( self, @@ -341,6 +311,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama **kwargs, ) @@ -352,7 +323,6 @@ def forward( class OpenaiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__() - self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) self.mlp = OpenaiMLP(config) @@ -409,7 +379,6 @@ class OpenaiPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OpenaiDecoderLayer"] - _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -418,6 +387,7 @@ class OpenaiPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] def _init_weights(self, module): std = self.config.initializer_range @@ -449,7 +419,6 @@ def __init__(self, config: OpenaiConfig): self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = OpenaiRotaryEmbedding(config=config) self.gradient_checkpointing = False - self.rope = OpenaiRotaryEmbedding(config) # Initialize weights and apply final processing self.post_init() @@ -574,8 +543,7 @@ def forward( ) -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): - ... +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... def load_balancing_loss_func( @@ -666,7 +634,7 @@ class OpenaiForCausalLM(OpenaiPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config: OpenaiConfig): + def __init__(self, config): super().__init__(config) self.model = OpenaiModel(config) self.vocab_size = config.vocab_size diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index b12c0c2a5cc9..30d58d49c3b3 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -17,16 +17,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple -from ...cache_utils import Cache +from typing import List, Optional, Tuple + import torch from torch import nn -from ...activations import ACT2FN + +from ...cache_utils import Cache, DynamicCache from ...integrations.flex_attention import flex_attention_forward +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + MoeModelOutputWithPast, +) +from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...utils import logging +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import ( - LlamaAttention, LlamaDecoderLayer, LlamaPreTrainedModel, LlamaRMSNorm, @@ -34,6 +41,7 @@ repeat_kv, ) from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel +from ..qwen2.modeling_qwen2 import Qwen2Attention from .configuration_openai_moe import OpenaiConfig @@ -44,12 +52,6 @@ class OpenaiRMSNorm(LlamaRMSNorm): pass -def swiglu(x, alpha: float = 1.702): - # Note we add an extra bias of 1 to the linear layer - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) - return out_glu * (x_linear + 1) - class OpenaiExperts(nn.Module): def __init__(self, config): super().__init__() @@ -60,14 +62,15 @@ def __init__(self, config): self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) - self.act_fn = ACT2FN[config.hidden_act] + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) + self.alpha = 1.702 - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: """ - This should really not be run on a single machine, as we are reaching compute bound: - - the inputs are expected to be "sorted" per expert already. - - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + When training is is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. Args: hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) @@ -76,10 +79,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor """ - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[:, None, :] - swiglu_ = swiglu(gate_up) - next_states = torch.bmm(swiglu_, self.down_proj) + self.down_proj_bias[:,None,:] + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute( + 2, 1, 0 + ) + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + with torch.no_grad(): + idx, top_x = torch.where( + expert_mask[expert_idx][0] + ) # idx: top-1/top-2 indicator, top_x: token indices + current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + gate_up = ( + current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + ) # (num_tokens, 2 * interm_dim) + gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) + gated_output = (up + 1) * glu # (num_tokens, interm_dim) + out = ( + gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + ) # (num_tokens, hidden_dim) + weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = next_states.view(-1, self.hidden_size) return next_states @@ -94,22 +125,55 @@ def __init__(self, config): def forward(self, hidden_states): # we don't slice weight as its not compile compatible + batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_scores = ( - torch.full_like(router_logits, float(0)).scatter_(1, router_indices, router_top_value).transpose(0, 1) - ) - - routed_in = hidden_states.repeat(self.num_local_experts, 1) - routed_out = self.experts(routed_in) - routed_out = routed_out * router_scores.reshape(self.num_local_experts, -1, 1) - hidden_states = routed_out.sum(dim=0)[None, ...] - return hidden_states, router_scores + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) + routed_out = self.experts(hidden_states, router_indices, router_top_value) + if self.training: + output_states = routed_out.view(batch_size, -1, self.hidden_dim) + else: + routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] + output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) + return output_states, router_scores class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): - pass + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = freqs + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(x.dtype), sin.to(x.dtype) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + first_half, second_half = torch.chunk(x, 2, dim=-1) + first_ = first_half * cos - second_half * sin + second_ = second_half * cos + first_half * sin + return torch.cat((first_, second_), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = _apply_rotary_emb(q, cos, sin) + k_embed = _apply_rotary_emb(k, cos, sin) + return q_embed, k_embed def eager_attention_forward( @@ -124,7 +188,9 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, query.shape[-2], -1) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + query.shape[0], -1, query.shape[-2], -1 + ) # TODO make sure the sink is like a new token attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: @@ -132,9 +198,9 @@ def eager_attention_forward( attn_weights = attn_weights + causal_mask attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = torch.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[...,:-1], value_states) + attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -149,7 +215,7 @@ def openai_flex_attention_forward( dropout: float = 0.0, **kwargs, ): - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2],-1) + sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2], -1) def attention_sink(score, b, h, q_idx, kv_idx): score = torch.cat([score, sinks], dim=-1) @@ -173,10 +239,23 @@ def attention_sink(score, b, h, q_idx, kv_idx): ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) -class OpenaiAttention(LlamaAttention): +class OpenaiAttention(Qwen2Attention): def __init__(self, config: OpenaiConfig, layer_idx: int): super().__init__(config, layer_idx) - self.register_buffer("sinks", torch.empty(config.num_attention_heads), persistent=True) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + class OpenaiDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OpenaiConfig, layer_idx: int): @@ -186,6 +265,7 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.mlp = OpenaiMLP(config) self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -228,28 +308,131 @@ def forward( if kwargs.get("output_router_logits", False): outputs += (router_logits,) return outputs - + + class OpenaiPreTrainedModel(LlamaPreTrainedModel): - config_class = OpenaiConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["OpenaiDecoderLayer"] + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] -class OpenaiModel(MixtralModel, OpenaiPreTrainedModel): +class OpenaiModel(MixtralModel): _no_split_modules = ["OpenaiDecoderLayer"] - def __init__(self, config: OpenaiConfig): - super().__init__(config) - self.rope = OpenaiRotaryEmbedding(config) - self.layers = nn.ModuleList([OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) - self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) -class OpenaiForCausalLM(MixtralForCausalLM, OpenaiPreTrainedModel): - def __init__(self, config: OpenaiConfig): - super().__init__(config) - self.model = OpenaiModel(config) +class OpenaiForCausalLM(MixtralForCausalLM): + pass __all__ = [ From 6065e15c6462a0d8cfaee524558c50cc4cb30b52 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 10:41:44 +0200 Subject: [PATCH 055/342] norm dtype --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- src/transformers/models/openai_moe/modular_openai_moe.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 8b611165166f..c30e47fccbcd 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -60,7 +60,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return (self.weight * hidden_states).to(input_dtype) # main diff with Llama def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 30d58d49c3b3..e660d21ed062 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -49,7 +49,12 @@ class OpenaiRMSNorm(LlamaRMSNorm): - pass + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) # main diff with Llama class OpenaiExperts(nn.Module): From 04bb8ffe97119e1ba2ebca491ad5047d6a8bcab4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 10:58:34 +0200 Subject: [PATCH 056/342] switch name convention --- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- .../models/auto/tokenization_auto.py | 2 +- .../openai_moe/configuration_openai_moe.py | 16 ++--- .../convert_openai_weights_to_hf.py | 16 ++--- .../models/openai_moe/modeling_openai_moe.py | 64 +++++++++---------- .../models/openai_moe/modular_openai_moe.py | 42 ++++++------ 7 files changed, 74 insertions(+), 74 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9ffd581a0d3f..b509f3ccdf84 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -242,7 +242,7 @@ ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), - ("openai-moe", "OpenaiConfig"), + ("openai_moe", "OpenAIMoeConfig"), ("openai-gpt", "OpenAIGPTConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), @@ -621,7 +621,7 @@ ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), - ("openai-moe", "openai"), + ("openai_moe", "OpenAIMoe"), ("openai-gpt", "OpenAI GPT"), ("opt", "OPT"), ("owlv2", "OWLv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index acd4c8fc33af..c9bc27f0742e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -232,7 +232,7 @@ ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), - ("openai-moe", "OpenaiModel"), + ("openai_moe", "OpenAIMoeModel"), ("openai-gpt", "OpenAIGPTModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), @@ -607,7 +607,7 @@ ("olmo2", "Olmo2ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), - ("openai-moe", "OpenaiForCausalLM"), + ("openai_moe", "OpenAiMoeForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 29a22fec3320..6389f74d429e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -407,7 +407,7 @@ ), ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ( - "openai-moe", + "openai_moe", ( None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None, diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 9725daa7fd5a..9134d87905a1 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -23,21 +23,21 @@ from ...modeling_rope_utils import rope_config_validation -class OpenaiConfig(PretrainedConfig): +class OpenAIMoeConfig(PretrainedConfig): r""" This will yield a configuration to that of the BERT [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. """ - model_type = "openai-moe" + model_type = "openai_moe" # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - # "layers.*.self_attn.q_proj": "colwise", - # "layers.*.self_attn.k_proj": "colwise", - # "layers.*.self_attn.v_proj": "colwise", - # "layers.*.self_attn.o_proj": "rowwise", - # "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", @@ -127,4 +127,4 @@ def __init__( ) -__all__ = ["OpenaiConfig"] +__all__ = ["OpenAIMoeConfig"] diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 38057c066a62..1a4c49d44bb7 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -25,8 +25,8 @@ from transformers import ( GenerationConfig, - OpenaiConfig, - OpenaiForCausalLM, + OpenAIMoeConfig, + OpenAIMoeForCausalLM, PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import TikTokenConverter @@ -87,7 +87,7 @@ def write_model( eos_token_id = 199999 if not instruct else [199999, 200018] pad_token_id = 128004 - config = OpenaiConfig() + config = OpenAIMoeConfig() print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} @@ -127,9 +127,9 @@ def write_model( del final_ gc.collect() - print("Loading the checkpoint in a OpenAI ") + print("Loading the checkpoint in a OpenAIMoe ") with torch.device("meta"): - model = OpenaiForCausalLM(config) + model = OpenAIMoeForCausalLM(config) model.load_state_dict(state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") del config._name_or_path @@ -141,7 +141,7 @@ def write_model( # Safety check: reload the converted model gc.collect() print("Reloading the model to check if it's saved correctly.") - OpenaiForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") # generation config @@ -183,7 +183,7 @@ def bytes_to_unicode(): return dict(zip(bs, cs)) import tiktoken -class OpenaiConverter(TikTokenConverter): +class OpenAIMoeConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): tokenizer = tiktoken.get_encoding(tiktoken_url) self.pattern = tokenizer._pat_str @@ -266,7 +266,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): "{% endif %}" ) - converter = OpenaiConverter( + converter = OpenAIMoeConverter( vocab_file=tokenizer_path, model_max_length=None, chat_template=chat_template if instruct else None, diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index c30e47fccbcd..fc5eb34f0c9f 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -39,17 +39,17 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging -from .configuration_openai_moe import OpenaiConfig +from .configuration_openai_moe import OpenAIMoeConfig logger = logging.get_logger(__name__) @use_kernel_forward_from_hub("RMSNorm") -class OpenaiRMSNorm(nn.Module): +class OpenAIMoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - OpenaiRMSNorm is equivalent to T5LayerNorm + OpenAIMoeRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -66,7 +66,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class OpenaiExperts(nn.Module): +class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts @@ -128,13 +128,13 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenaiMLP(nn.Module): +class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_local_experts = config.num_local_experts - self.experts = OpenaiExperts(config) + self.experts = OpenAIMoeExperts(config) self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) def forward(self, hidden_states): @@ -154,8 +154,8 @@ def forward(self, hidden_states): return output_states, router_scores -class OpenaiRotaryEmbedding(nn.Module): - def __init__(self, config: OpenaiConfig, device=None): +class OpenAIMoeRotaryEmbedding(nn.Module): + def __init__(self, config: OpenAIMoeConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: @@ -248,10 +248,10 @@ def eager_attention_forward( return attn_output, attn_weights -class OpenaiAttention(nn.Module): +class OpenAIMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: OpenaiConfig, layer_idx: int): + def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx @@ -320,14 +320,14 @@ def forward( return attn_output, attn_weights -class OpenaiDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: OpenaiConfig, layer_idx: int): +class OpenAIMoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) - self.mlp = OpenaiMLP(config) - self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = OpenAIMoeAttention(config=config, layer_idx=layer_idx) + self.mlp = OpenAIMoeMLP(config) + self.input_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] def forward( @@ -374,11 +374,11 @@ def forward( @auto_docstring -class OpenaiPreTrainedModel(PreTrainedModel): - config_class = OpenaiConfig +class OpenAIMoePreTrainedModel(PreTrainedModel): + config_class = OpenAIMoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["OpenaiDecoderLayer"] + _no_split_modules = ["OpenAIMoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -399,25 +399,25 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, OpenaiRMSNorm): + elif isinstance(module, OpenAIMoeRMSNorm): module.weight.data.fill_(1.0) @auto_docstring -class OpenaiModel(OpenaiPreTrainedModel): - _no_split_modules = ["OpenaiDecoderLayer"] +class OpenAIMoeModel(OpenAIMoePreTrainedModel): + _no_split_modules = ["OpenAIMoeDecoderLayer"] - def __init__(self, config: OpenaiConfig): + def __init__(self, config: OpenAIMoeConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [OpenaiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [OpenAIMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = OpenaiRotaryEmbedding(config=config) + self.norm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = OpenAIMoeRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -629,14 +629,14 @@ def load_balancing_loss_func( @auto_docstring -class OpenaiForCausalLM(OpenaiPreTrainedModel, GenerationMixin): +class OpenAIMoeForCausalLM(OpenAIMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) - self.model = OpenaiModel(config) + self.model = OpenAIMoeModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef @@ -691,10 +691,10 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, OpenaiForCausalLM + >>> from transformers import AutoTokenizer, OpenAIMoeForCausalLM - >>> model = OpenaiForCausalLM.from_pretrained("mistralai/Openai-8x7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Openai-8x7B-v0.1") + >>> model = OpenAIMoeForCausalLM.from_pretrained("mistralai/OpenAIMoe-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/OpenAIMoe-8x7B-v0.1") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -760,4 +760,4 @@ def forward( ) -__all__ = ["OpenaiForCausalLM", "OpenaiModel", "OpenaiPreTrainedModel"] +__all__ = ["OpenAIMoeForCausalLM", "OpenAIMoeModel", "OpenAIMoePreTrainedModel"] diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index e660d21ed062..abab709e423a 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -42,13 +42,13 @@ ) from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from ..qwen2.modeling_qwen2 import Qwen2Attention -from .configuration_openai_moe import OpenaiConfig +from .configuration_openai_moe import OpenAIMoeConfig logger = logging.get_logger(__name__) -class OpenaiRMSNorm(LlamaRMSNorm): +class OpenAIMoeRMSNorm(LlamaRMSNorm): def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -57,7 +57,7 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -class OpenaiExperts(nn.Module): +class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts @@ -119,13 +119,13 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenaiMLP(nn.Module): +class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_local_experts = config.num_local_experts - self.experts = OpenaiExperts(config) + self.experts = OpenAIMoeExperts(config) self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) def forward(self, hidden_states): @@ -145,7 +145,7 @@ def forward(self, hidden_states): return output_states, router_scores -class OpenaiRotaryEmbedding(LlamaRotaryEmbedding): +class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): @@ -244,8 +244,8 @@ def attention_sink(score, b, h, q_idx, kv_idx): ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) -class OpenaiAttention(Qwen2Attention): - def __init__(self, config: OpenaiConfig, layer_idx: int): +class OpenAIMoeAttention(Qwen2Attention): + def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__(config, layer_idx) self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -262,14 +262,14 @@ def __init__(self, config: OpenaiConfig, layer_idx: int): self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) -class OpenaiDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: OpenaiConfig, layer_idx: int): +class OpenAIMoeDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__(config, layer_idx) self.hidden_size = config.hidden_size - self.self_attn = OpenaiAttention(config=config, layer_idx=layer_idx) - self.mlp = OpenaiMLP(config) - self.input_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = OpenaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = OpenAIMoeAttention(config=config, layer_idx=layer_idx) + self.mlp = OpenAIMoeMLP(config) + self.input_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] def forward( @@ -315,12 +315,12 @@ def forward( return outputs -class OpenaiPreTrainedModel(LlamaPreTrainedModel): +class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] -class OpenaiModel(MixtralModel): - _no_split_modules = ["OpenaiDecoderLayer"] +class OpenAIMoeModel(MixtralModel): + _no_split_modules = ["OpenAIMoeDecoderLayer"] @can_return_tuple @auto_docstring @@ -436,12 +436,12 @@ def forward( ) -class OpenaiForCausalLM(MixtralForCausalLM): +class OpenAIMoeForCausalLM(MixtralForCausalLM): pass __all__ = [ - "OpenaiForCausalLM", - "OpenaiModel", - "OpenaiPreTrainedModel", + "OpenAIMoeForCausalLM", + "OpenAIMoeModel", + "OpenAIMoePreTrainedModel", ] From 6d0bfededadeba6e9f02f714229cb4f5b6840bc8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:02:04 +0200 Subject: [PATCH 057/342] Update __init__.py --- src/transformers/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index d2988a960b8b..96e28e4e35f8 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -217,6 +217,7 @@ from .omdet_turbo import * from .oneformer import * from .openai import * + from .openai_moe import * from .opt import * from .owlv2 import * from .owlvit import * From 09d4d7d28745544ac5bd570f8ee9238af415f06e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:05:00 +0200 Subject: [PATCH 058/342] Update modeling_auto.py --- src/transformers/models/auto/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index c9bc27f0742e..4925fc785cf3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -607,7 +607,7 @@ ("olmo2", "Olmo2ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), - ("openai_moe", "OpenAiMoeForCausalLM"), + ("openai_moe", "OpenAIMoeForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), From 0148f6838690c7a05857ace279993850b2c99493 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:15:18 +0200 Subject: [PATCH 059/342] Add back old model --- src/transformers/models/openai/__init__.py | 30 + .../models/openai/configuration_openai.py | 2 +- ...penai_original_tf_checkpoint_to_pytorch.py | 74 ++ .../models/openai/modeling_openai.py | 867 ++++++++++++++++ .../models/openai/modeling_tf_openai.py | 937 ++++++++++++++++++ .../models/openai/tokenization_openai.py | 396 ++++++++ .../models/openai/tokenization_openai_fast.py | 66 ++ 7 files changed, 2371 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/openai/__init__.py create mode 100755 src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/openai/modeling_openai.py create mode 100644 src/transformers/models/openai/modeling_tf_openai.py create mode 100644 src/transformers/models/openai/tokenization_openai.py create mode 100644 src/transformers/models/openai/tokenization_openai_fast.py diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai/__init__.py new file mode 100644 index 000000000000..a07b0ab669f3 --- /dev/null +++ b/src/transformers/models/openai/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_openai import * + from .modeling_openai import * + from .modeling_tf_openai import * + from .tokenization_openai import * + from .tokenization_openai_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index 7dc1525fe0ef..b4f2fae9d304 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -153,4 +153,4 @@ def __init__( super().__init__(**kwargs) -__all__ = ["OpenAIGPTConfig"] \ No newline at end of file +__all__ = ["OpenAIGPTConfig"] diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py new file mode 100755 index 000000000000..3d5218c20426 --- /dev/null +++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI GPT checkpoint.""" + +import argparse + +import torch + +from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + + +def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): + # Construct model + if openai_config_file == "": + config = OpenAIGPTConfig() + else: + config = OpenAIGPTConfig.from_json_file(openai_config_file) + model = OpenAIGPTModel(config) + + # Load weights from numpy + load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME + print(f"Save PyTorch model to {pytorch_weights_dump_path}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {pytorch_config_dump_path}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--openai_checkpoint_folder_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--openai_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained OpenAI model. \n" + "This specifies the model architecture." + ), + ) + args = parser.parse_args() + convert_openai_checkpoint_to_pytorch( + args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path + ) diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py new file mode 100644 index 000000000000..55aa53c40ffd --- /dev/null +++ b/src/transformers/models/openai/modeling_openai.py @@ -0,0 +1,867 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT model.""" + +import json +import math +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import gelu_new, get_activation, silu +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...utils import ( + ModelOutput, + auto_docstring, + logging, +) +from .configuration_openai import OpenAIGPTConfig + + +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): + """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" + import re + + import numpy as np + + if ".ckpt" in openai_checkpoint_folder_path: + openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) + + logger.info(f"Loading weights from {openai_checkpoint_folder_path}") + + with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: + names = json.load(names_handle) + with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: + shapes = json.load(shapes_handle) + offsets = np.cumsum([np.prod(shape) for shape in shapes]) + init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] + init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] + init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] + + # This was used when we had a single embedding matrix for positions and tokens + # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) + # del init_params[1] + init_params = [arr.squeeze() for arr in init_params] + + # Check that the token and position embeddings weight dimensions map those of the init parameters. + if model.tokens_embed.weight.shape != init_params[1].shape: + raise ValueError( + f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:" + f" {init_params[1].shape}" + ) + + if model.positions_embed.weight.shape != init_params[0].shape: + raise ValueError( + f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:" + f" {init_params[0].shape}" + ) + + model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) + model.positions_embed.weight.data = torch.from_numpy(init_params[0]) + names.pop(0) + # Pop position and token embedding arrays + init_params.pop(0) + init_params.pop(0) + + for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): + name = name[6:] # skip "model/" + if name[-2:] != ":0": + raise ValueError(f"Layer {name} does not end with :0") + name = name[:-2] + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "w": + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + # Ensure that the pointer and array have compatible shapes. + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu} + + +class Attention(nn.Module): + def __init__(self, nx, n_positions, config, scale=False): + super().__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implementation] + if n_state % config.n_head != 0: + raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}") + self.register_buffer( + "bias", + torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions), + persistent=False, + ) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + + self.c_attn = Conv1D(n_state * 3, nx) + self.c_proj = Conv1D(n_state, nx) + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_head, self.split_size // self.n_head, self.pruned_heads + ) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Update hyper params + self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) + self.n_head = self.n_head - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights + # XD: self.b may be larger than w, so we need to crop it + b = self.bias[:, :, : w.size(-2), : w.size(-1)] + w = w * b + -1e4 * (1 - b) + + if attention_mask is not None: + # Apply the attention mask + w = w + attention_mask + + w = nn.functional.softmax(w, dim=-1) + w = self.attn_dropout(w) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [torch.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states + if k: + return x.permute(0, 2, 3, 1) + else: + return x.permute(0, 2, 1, 3) + + def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a) + + outputs = [a] + attn_outputs[1:] + return outputs # a, (attentions) + + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super().__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = ACT_FNS[config.afn] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class Block(nn.Module): + def __init__(self, n_positions, config, scale=False): + super().__init__() + nx = config.n_embd + self.attn = Attention(nx, n_positions, config, scale) + self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) + self.mlp = MLP(4 * nx, config) + self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) + + def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): + attn_outputs = self.attn( + x, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + a = attn_outputs[0] + + n = self.ln_1(x + a) + m = self.mlp(n) + h = self.ln_2(n + m) + + outputs = [h] + attn_outputs[1:] + return outputs + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT +class OpenAIGPTSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`OpenAIGPTConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: OpenAIGPTConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +@auto_docstring +class OpenAIGPTPreTrainedModel(PreTrainedModel): + config_class = OpenAIGPTConfig + load_tf_weights = load_tf_weights_in_openai_gpt + base_model_prefix = "transformer" + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + mc_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@auto_docstring +class OpenAIGPTModel(OpenAIGPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) + self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) + + self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tokens_embed + + def set_input_embeddings(self, new_embeddings): + self.tokens_embed = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + # Code is different from when we had a single embedding matrix from position and token embeddings + position_ids = self.position_ids[None, : input_shape[-1]] + + # Attention mask. + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.tokens_embed(input_ids) + position_embeds = self.positions_embed(position_ids) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + token_type_embeds = self.tokens_embed(token_type_ids) + else: + token_type_embeds = 0 + hidden_states = inputs_embeds + position_embeds + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions) + hidden_states = outputs[0] + if output_attentions: + all_attentions = all_attentions + (outputs[1],) + + hidden_states = hidden_states.view(*output_shape) + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +@auto_docstring( + custom_intro=""" + OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """ +) +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = OpenAIGPTModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Flatten the tokens + loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, + logits=lm_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + # Overwritten -- old model with reduced inputs + return {"input_ids": input_ids} + + +@auto_docstring( + custom_intro=""" + OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for + RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the + input embeddings, the classification head takes as input the input of a specified classification token index in the + input sequence). + """ +) +class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 1 + self.transformer = OpenAIGPTModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = OpenAIGPTSequenceSummary(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are + ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Examples: + + ```python + >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") + >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") + >>> tokenizer.add_special_tokens( + ... {"cls_token": "[CLS]"} + ... ) # Add a [CLS] to the vocabulary (we should train it also!) + >>> model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + lm_loss, mc_loss = None, None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return OpenAIGPTDoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer). + [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the + last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding + token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since + it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take + the last value in each row of the batch). + """ +) +class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = OpenAIGPTModel(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + # Ensure the batch size is > 1 if there is no padding. + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "OpenAIGPTDoubleHeadsModel", + "OpenAIGPTForSequenceClassification", + "OpenAIGPTLMHeadModel", + "OpenAIGPTModel", + "OpenAIGPTPreTrainedModel", + "load_tf_weights_in_openai_gpt", +] diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py new file mode 100644 index 000000000000..3856711d1062 --- /dev/null +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -0,0 +1,937 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 OpenAI GPT model.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFConv1D, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_openai import OpenAIGPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt" +_CONFIG_FOR_DOC = "OpenAIGPTConfig" + + +class TFAttention(keras.layers.Layer): + def __init__(self, nx, config, scale=False, **kwargs): + super().__init__(**kwargs) + + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implementation] + assert n_state % config.n_head == 0, ( + f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}" + ) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.output_attentions = config.output_attentions + + self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") + self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") + self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) + self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) + self.n_state = n_state + self.pruned_heads = set() + + def prune_heads(self, heads): + pass + + @staticmethod + def causal_attention_mask(nd, ns): + """ + 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), + -1, ns-nd), but doesn't produce garbage on TPUs. + """ + i = tf.range(nd)[:, None] + j = tf.range(ns) + m = i >= j - ns + nd + return m + + def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False): + # q, k, v have shape [batch, heads, sequence, features] + w = tf.matmul(q, k, transpose_b=True) + if self.scale: + dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores + w = w / tf.math.sqrt(dk) + + # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. + _, _, nd, ns = shape_list(w) + b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype) + b = tf.reshape(b, [1, 1, nd, ns]) + w = w * b - 1e4 * (1 - b) + + if attention_mask is not None: + # Apply the attention mask + attention_mask = tf.cast(attention_mask, dtype=w.dtype) + w = w + attention_mask + + w = stable_softmax(w, axis=-1) + w = self.attn_dropout(w, training=training) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [tf.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = tf.transpose(x, [0, 2, 1, 3]) + x_shape = shape_list(x) + new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] + return tf.reshape(x, new_x_shape) + + def split_heads(self, x): + x_shape = shape_list(x) + new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) + + def call(self, x, attention_mask, head_mask, output_attentions, training=False): + x = self.c_attn(x) + query, key, value = tf.split(x, 3, axis=2) + query = self.split_heads(query) + key = self.split_heads(key) + value = self.split_heads(value) + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a, training=training) + + outputs = [a] + attn_outputs[1:] + return outputs # a, (attentions) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "c_attn", None) is not None: + with tf.name_scope(self.c_attn.name): + self.c_attn.build([None, None, self.n_state * 3]) + if getattr(self, "c_proj", None) is not None: + with tf.name_scope(self.c_proj.name): + self.c_proj.build([None, None, self.n_state]) + + +class TFMLP(keras.layers.Layer): + def __init__(self, n_state, config, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") + self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") + self.act = get_tf_activation("gelu") + self.dropout = keras.layers.Dropout(config.resid_pdrop) + self.nx = nx + self.n_state = n_state + + def call(self, x, training=False): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + h2 = self.dropout(h2, training=training) + return h2 + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "c_fc", None) is not None: + with tf.name_scope(self.c_fc.name): + self.c_fc.build([None, None, self.n_state]) + if getattr(self, "c_proj", None) is not None: + with tf.name_scope(self.c_proj.name): + self.c_proj.build([None, None, self.nx]) + + +class TFBlock(keras.layers.Layer): + def __init__(self, config, scale=False, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + self.attn = TFAttention(nx, config, scale, name="attn") + self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") + self.mlp = TFMLP(4 * nx, config, name="mlp") + self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") + self.nx = nx + + def call(self, x, attention_mask, head_mask, output_attentions, training=False): + output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training) + a = output_attn[0] # output_attn: a, (attentions) + + n = self.ln_1(x + a) + m = self.mlp(n, training=training) + h = self.ln_2(n + m) + + outputs = [h] + output_attn[1:] + return outputs # x, (attentions) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "ln_1", None) is not None: + with tf.name_scope(self.ln_1.name): + self.ln_1.build([None, None, self.nx]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "ln_2", None) is not None: + with tf.name_scope(self.ln_2.name): + self.ln_2.build([None, None, self.nx]) + + +@keras_serializable +class TFOpenAIGPTMainLayer(keras.layers.Layer): + config_class = OpenAIGPTConfig + + def __init__(self, config, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.return_dict = config.use_return_dict + self.num_hidden_layers = config.n_layer + self.n_embd = config.n_embd + self.n_positions = config.n_positions + self.initializer_range = config.initializer_range + + self.tokens_embed = TFSharedEmbeddings( + config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed" + ) + self.drop = keras.layers.Dropout(config.embd_pdrop) + self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] + + def build(self, input_shape=None): + with tf.name_scope("positions_embed"): + self.positions_embed = self.add_weight( + name="embeddings", + shape=[self.n_positions, self.n_embd], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "tokens_embed", None) is not None: + with tf.name_scope(self.tokens_embed.name): + self.tokens_embed.build(None) + if getattr(self, "h", None) is not None: + for layer in self.h: + with tf.name_scope(layer.name): + layer.build(None) + + def get_input_embeddings(self): + return self.tokens_embed + + def set_input_embeddings(self, value): + self.tokens_embed.weight = value + self.tokens_embed.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutput]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0) + + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + one_cst = tf.constant(1.0) + attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) + attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) + else: + attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = self.tokens_embed(input_ids, mode="embedding") + position_embeds = tf.gather(self.positions_embed, position_ids) + if token_type_ids is not None: + token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) + check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids") + token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding") + else: + token_type_embeds = 0 + hidden_states = inputs_embeds + position_embeds + token_type_embeds + hidden_states = self.drop(hidden_states, training=training) + + output_shape = input_shape + [shape_list(hidden_states)[-1]] + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) + + outputs = block( + hidden_states, + attention_mask, + head_mask[i], + output_attentions, + training=training, + ) + hidden_states = outputs[0] + if output_attentions: + all_attentions = all_attentions + (outputs[1],) + + hidden_states = tf.reshape(hidden_states, output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + # let the number of heads free (-1) so we can extract attention even after head pruning + attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] + all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = OpenAIGPTConfig + base_model_prefix = "transformer" + + +@dataclass +class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[tf.Tensor] = None + mc_logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +OPENAI_GPT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OPENAI_GPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutput]: + outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +@add_start_docstrings( + """ + OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + # OpenAIGPT does not have past caching features + self.supports_xla_generation = False + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFCausalLMOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + + logits = self.transformer.tokens_embed(hidden_states, mode="linear") + + loss = None + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels, shifted_logits) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, inputs, **kwargs): + return {"input_ids": inputs} + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +@add_start_docstrings( + """ + OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for + RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the + input embeddings, the classification head takes as input the input of a specified classification token index in the + input sequence). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + config.num_labels = 1 + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + self.multiple_choice_head = TFSequenceSummary( + config, initializer_range=config.initializer_range, name="multiple_choice_head" + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + mc_token_ids: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFOpenAIGPTDoubleHeadsModelOutput]: + r""" + mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + + Return: + + Examples: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") + >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size + >>> print(tokenizer.cls_token_id, len(tokenizer)) # The newly token the last token of the vocabulary + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoding = tokenizer(choices, return_tensors="tf") + >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()} + >>> inputs["mc_token_ids"] = tf.constant( + ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1] + ... )[ + ... None, : + ... ] # Batch size 1 + >>> outputs = model(inputs) + >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] + ```""" + + if input_ids is not None: + input_shapes = shape_list(input_ids) + else: + input_shapes = shape_list(inputs_embeds)[:-1] + + seq_length = input_shapes[-1] + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + transformer_outputs = self.transformer( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) + if return_dict and output_hidden_states: + # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the + # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) + all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) + else: + all_hidden_states = None + lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) + mc_logits = tf.squeeze(mc_logits, axis=-1) + + if not return_dict: + return (lm_logits, mc_logits) + transformer_outputs[1:] + + return TFOpenAIGPTDoubleHeadsModelOutput( + logits=lm_logits, + mc_logits=mc_logits, + hidden_states=all_hidden_states, + attentions=transformer_outputs.attentions, + ) + + @property + def input_signature(self): + return { + "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), + "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), + } + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "multiple_choice_head", None) is not None: + with tf.name_scope(self.multiple_choice_head.name): + self.multiple_choice_head.build(None) + + +@add_start_docstrings( + """ + The OpenAI GPT Model transformer with a sequence classification head on top (linear layer). + + [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.score = keras.layers.Dense( + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="score", + use_bias=False, + ) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSequenceClassifierOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + logits_shape = shape_list(logits) + batch_size = logits_shape[0] + + if self.config.pad_token_id is None: + last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) + else: + if input_ids is not None: + token_indices = tf.range(shape_list(input_ids)[-1]) + non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) + last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) + else: + last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + loss = None + + pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) + + if labels is not None: + if self.config.pad_token_id is None and logits_shape[0] != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + + loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "score", None) is not None: + with tf.name_scope(self.score.name): + self.score.build([None, None, self.config.n_embd]) + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +__all__ = [ + "TFOpenAIGPTDoubleHeadsModel", + "TFOpenAIGPTForSequenceClassification", + "TFOpenAIGPTLMHeadModel", + "TFOpenAIGPTMainLayer", + "TFOpenAIGPTModel", + "TFOpenAIGPTPreTrainedModel", +] diff --git a/src/transformers/models/openai/tokenization_openai.py b/src/transformers/models/openai/tokenization_openai.py new file mode 100644 index 000000000000..cbfb41fc888f --- /dev/null +++ b/src/transformers/models/openai/tokenization_openai.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" + +import json +import os +import re +import unicodedata +from typing import Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length + strings) + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def text_standardize(text): + """ + fixes some issues the spacy tokenizer had on books corpus also does some whitespace standardization + """ + text = text.replace("—", "-") + text = text.replace("–", "-") + text = text.replace("―", "-") + text = text.replace("…", "...") + text = text.replace("´", "'") + text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text) + text = re.sub(r"\s*\n\s*", " \n ", text) + text = re.sub(r"[^\S\n]+", " ", text) + return text.strip() + + +class OpenAIGPTTokenizer(PreTrainedTokenizer): + """ + Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities: + + - lowercases all inputs, + - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's + `BasicTokenizer` if not. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): + try: + import ftfy + from spacy.lang.en import English + + _nlp = English() + self.nlp = _nlp.tokenizer + self.fix_text = ftfy.fix_text + except ImportError: + logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") + self.nlp = BasicTokenizer(do_lower_case=True) + self.fix_text = None + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[1:-1] + merges = [tuple(merge.split()) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + + super().__init__(unk_token=unk_token, **kwargs) + + @property + def do_lower_case(self): + return True + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + word = tuple(token[:-1]) + (token[-1] + "",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n ": + word = "\n" + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + split_tokens = [] + if self.fix_text is None: + # Using BERT's BasicTokenizer + text = self.nlp.tokenize(text) + for token in text: + split_tokens.extend(list(self.bpe(token).split(" "))) + else: + # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) + text = self.nlp(text_standardize(self.fix_text(text))) + for token in text: + split_tokens.extend(list(self.bpe(token.text.lower()).split(" "))) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an id in a token (BPE) using the vocab.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = "".join(tokens).replace("", " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + +__all__ = ["OpenAIGPTTokenizer"] diff --git a/src/transformers/models/openai/tokenization_openai_fast.py b/src/transformers/models/openai/tokenization_openai_fast.py new file mode 100644 index 000000000000..c17d7d29b7dd --- /dev/null +++ b/src/transformers/models/openai/tokenization_openai_fast.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for OpenAI GPT.""" + +from typing import Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_openai import OpenAIGPTTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with + the following peculiarities: + + - lower case all inputs + - uses BERT's BasicTokenizer for pre-BPE tokenization + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = OpenAIGPTTokenizer + + def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="", **kwargs): + super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs) + + @property + def do_lower_case(self): + return True + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["OpenAIGPTTokenizerFast"] From 823428a93063076f78fa1f8079cc19546a584590 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:21:49 +0200 Subject: [PATCH 060/342] Finalize rename --- docs/source/en/_toctree.yml | 4 +- .../en/model_doc/{openai.md => openai_moe.md} | 0 tests/models/openai/test_modeling_openai.py | 977 ++++-------------- tests/models/openai_moe/__init__.py | 0 .../openai_moe/test_modeling_openai_moe.py | 932 +++++++++++++++++ 5 files changed, 1112 insertions(+), 801 deletions(-) rename docs/source/en/model_doc/{openai.md => openai_moe.md} (100%) create mode 100644 tests/models/openai_moe/__init__.py create mode 100644 tests/models/openai_moe/test_modeling_openai_moe.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d6bf7c8ee679..3fe78f64459e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -595,8 +595,8 @@ title: OLMoE - local: model_doc/open-llama title: Open-Llama - - local: model_doc/openai - title: openai + - local: model_doc/openai_moe + title: OpenAIMoe - local: model_doc/opt title: OPT - local: model_doc/pegasus diff --git a/docs/source/en/model_doc/openai.md b/docs/source/en/model_doc/openai_moe.md similarity index 100% rename from docs/source/en/model_doc/openai.md rename to docs/source/en/model_doc/openai_moe.md diff --git a/tests/models/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py index 4239e0816bf7..bba4ad8660fb 100644 --- a/tests/models/openai/test_modeling_openai.py +++ b/tests/models/openai/test_modeling_openai.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,24 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch openai model.""" + import unittest -from packaging import version -from parameterized import parameterized - -from transformers import AutoTokenizer, OpenaiConfig, StaticCache, is_torch_available, set_seed -from transformers.generation.configuration_utils import GenerationConfig -from transformers.testing_utils import ( - Expectations, - cleanup, - require_read_token, - require_torch, - require_torch_accelerator, - slow, - torch_device, -) +from transformers import is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -40,25 +28,22 @@ import torch from transformers import ( - LlamaTokenizer, - OpenaiForCausalLM, - OpenaiForQuestionAnswering, - OpenaiForSequenceClassification, - OpenaiForTokenClassification, - OpenaiModel, + OpenAIGPTConfig, + OpenAIGPTDoubleHeadsModel, + OpenAIGPTForSequenceClassification, + OpenAIGPTLMHeadModel, + OpenAIGPTModel, ) - from transformers.models.openai.modeling_openai import OpenaiRotaryEmbedding -class OpenaiModelTester: +class OpenAIGPTModelTester: def __init__( self, parent, batch_size=13, seq_length=7, is_training=True, - use_input_mask=True, - use_token_type_ids=False, + use_token_type_ids=True, use_labels=True, vocab_size=99, hidden_size=32, @@ -74,14 +59,12 @@ def __init__( initializer_range=0.02, num_labels=3, num_choices=4, - pad_token_id=0, scope=None, ): self.parent = parent self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_input_mask = use_input_mask self.use_token_type_ids = use_token_type_ids self.use_labels = use_labels self.vocab_size = vocab_size @@ -98,16 +81,12 @@ def __init__( self.initializer_range = initializer_range self.num_labels = num_labels self.num_choices = num_choices - self.pad_token_id = pad_token_id self.scope = scope + self.pad_token_id = self.vocab_size - 1 def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - token_type_ids = None if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) @@ -120,813 +99,213 @@ def prepare_config_and_inputs(self): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return OpenaiConfig( + config = OpenAIGPTConfig( vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, + n_embd=self.hidden_size, + n_layer=self.num_hidden_layers, + n_head=self.num_attention_heads, + # intermediate_size=self.intermediate_size, + # hidden_act=self.hidden_act, + # hidden_dropout_prob=self.hidden_dropout_prob, + # attention_probs_dropout_prob=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + # type_vocab_size=self.type_vocab_size, + # initializer_range=self.initializer_range pad_token_id=self.pad_token_id, ) - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = OpenaiModel(config=config) + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def create_and_check_openai_gpt_model(self, config, input_ids, head_mask, token_type_ids, *args): + model = OpenAIGPTModel(config=config) model.to(torch_device) model.eval() - result = model(input_ids, attention_mask=input_mask) + + result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) + result = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): + model = OpenAIGPTLMHeadModel(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): + model = OpenAIGPTDoubleHeadsModel(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_openai_gpt_for_sequence_classification( + self, config, input_ids, head_mask, token_type_ids, *args + ): + config.num_labels = self.num_labels + model = OpenAIGPTForSequenceClassification(config) + model.to(torch_device) + model.eval() + + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( config, input_ids, + head_mask, token_type_ids, - input_mask, sequence_labels, token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "head_mask": head_mask, + } + return config, inputs_dict @require_torch -class OpenaiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - ( - OpenaiModel, - OpenaiForCausalLM, - OpenaiForSequenceClassification, - OpenaiForQuestionAnswering, - OpenaiForTokenClassification, - ) + (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification) if is_torch_available() else () ) - test_headmasking = False - test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez - - # Need to use `0.8` instead of `0.9` for `test_cpu_offload` - # This is because we are hitting edge cases with the causal_mask buffer - model_split_percents = [0.5, 0.7, 0.8] + pipeline_model_mapping = ( + { + "feature-extraction": OpenAIGPTModel, + "text-classification": OpenAIGPTForSequenceClassification, + "text-generation": OpenAIGPTLMHeadModel, + "zero-shot": OpenAIGPTForSequenceClassification, + } + if is_torch_available() + else {} + ) - # used in `test_torch_compile_for_training` - _torch_compile_train_cls = OpenaiForCausalLM if is_torch_available() else None + # TODO: Fix the failed tests + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + if pipeline_test_case_name == "ZeroShotClassificationPipelineTests": + # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers. + # `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a + # tiny config could not be created. + return True + + return False + + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "OpenAIGPTDoubleHeadsModel": + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["input_ids"] = inputs_dict["labels"] + inputs_dict["token_type_ids"] = inputs_dict["labels"] + inputs_dict["mc_token_ids"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["mc_labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict def setUp(self): - self.model_tester = OpenaiModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenaiConfig, hidden_size=37) + self.model_tester = OpenAIGPTModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) def test_config(self): self.config_tester.run_common_tests() - def test_model(self): + def test_openai_gpt_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) + self.model_tester.create_and_check_openai_gpt_model(*config_and_inputs) - def test_model_various_embeddings(self): + def test_openai_gpt_lm_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_openai_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = OpenaiForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = OpenaiModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = OpenaiModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - - # Sanity check Yarn RoPE scaling - # Scaling should be over the entire input - config.rope_scaling = {"type": "yarn", "factor": scaling_factor} - yarn_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_short, original_cos_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - def test_model_loading_old_rope_configs(self): - def _reinitialize_config(base_config, new_kwargs): - # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation - # steps. - base_config_dict = base_config.to_dict() - new_config = OpenaiConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) - return new_config - - # from untouched config -> ✅ - base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() - original_model = OpenaiForCausalLM(base_config).to(torch_device) - original_model(**model_inputs) - - # from a config with the expected rope configuration -> ✅ - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC - config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) - config = _reinitialize_config( - base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} - ) - self.assertTrue(config.rope_scaling["type"] == "linear") - self.assertTrue(config.rope_scaling["rope_type"] == "linear") - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("factor field", logs.output[0]) - - # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config( - base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} - ) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("Unrecognized keys", logs.output[0]) - - # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception - with self.assertRaises(KeyError): - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - - -@require_torch_accelerator -class OpenaiIntegrationTest(unittest.TestCase): - def tearDown(self): - # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves - # some memory allocated in the cache, which means some object is not being released properly. This causes some - # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. - # Investigate the root cause. - cleanup(torch_device, gc_collect=False) - - @slow - @require_read_token - def test_openai_3_1_hard(self): - """ - An integration test for openai 3.1. It tests against a long output to ensure the subtle numerical differences - from openai 3.1.'s RoPE can be detected - """ - # diff on `EXPECTED_TEXT`: - # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. - EXPECTED_TEXT = ( - "Tell me about the french revolution. The french revolution was a period of radical political and social " - "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " - "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " - "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " - "assembly that had not met since 1614. The Third Estate, which represented the common people, " - "demanded greater representation and eventually broke away to form the National Assembly. This marked " - "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" - ) - - tokenizer = AutoTokenizer.from_pretrained("meta-openai/Meta-Openai-3.1-8B-Instruct") - model = OpenaiForCausalLM.from_pretrained( - "meta-openai/Meta-Openai-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 - ) - input_text = ["Tell me about the french revolution."] - model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) - - generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(generated_text, EXPECTED_TEXT) - - @slow - @require_read_token - def test_model_7b_logits_bf16(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = OpenaiForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - # Expected mean on dim = -1 - - # fmt: off - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), - ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), - ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), - ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), - ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) - }) - # fmt: on - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) - - @slow - @require_read_token - def test_model_7b_logits(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = OpenaiForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.float16 - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - - # fmt: off - # Expected mean on dim = -1 - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), - ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), - ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), - ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) - }) - # fmt: on - - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) - @slow - def test_model_7b_dola_generation(self): - # ground truth text generated with dola_layers="low", repetition_penalty=1.2 - EXPECTED_TEXT_COMPLETION = ( - "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " - "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " - "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " - "understanding of space and time." - ) - prompt = "Simply put, the theory of relativity states that " - tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") - model = OpenaiForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 - ) - model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + def test_openai_gpt_double_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs) - # greedy generation outputs - generated_ids = model.generate( - **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" - ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + def test_openai_gpt_classification_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs) @slow - @require_torch_accelerator - @require_read_token - def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - if version.parse(torch.__version__) < version.parse("2.3.0"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - NUM_TOKENS_TO_GENERATE = 40 - # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test - # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " - "theory of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " - "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] - - prompts = [ - "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", - ] - tokenizer = LlamaTokenizer.from_pretrained( - "meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right" - ) - model = OpenaiForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + def test_model_from_pretrained(self): + model_name = "openai-community/openai-gpt" + model = OpenAIGPTModel.from_pretrained(model_name) + self.assertIsNotNone(model) - # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) +@require_torch +class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase): @slow - @require_read_token - def test_export_static_cache(self): - if version.parse(torch.__version__) < version.parse("2.4.0"): - self.skipTest(reason="This test requires torch >= 2.4 to run.") - - from transformers.integrations.executorch import ( - TorchExportableModuleWithStaticCache, - convert_and_export_with_cache, - ) - - openai_models = { - "meta-openai/Openai-3.2-1B": [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all " - "observers, regardless of their location, and 2) the laws of physics are the same for all observers" - ], - } - - for openai_model_ckp, EXPECTED_TEXT_COMPLETION in openai_models.items(): - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(openai_model_ckp, pad_token="", padding_side="right") - max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ - "input_ids" - ].shape[-1] - - # Load model - device = "cpu" - dtype = torch.bfloat16 - cache_implementation = "static" - attn_implementation = "sdpa" - batch_size = 1 - model = OpenaiForCausalLM.from_pretrained( - openai_model_ckp, - device_map=device, - torch_dtype=dtype, - attn_implementation=attn_implementation, - generation_config=GenerationConfig( - use_cache=True, - cache_implementation=cache_implementation, - max_length=max_generation_length, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_generation_length, - "device": device, - }, - ), - ) - - prompts = ["Simply put, the theory of relativity states that "] - prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - prompt_token_ids = prompt_tokens["input_ids"] - max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] - - # Static Cache + export - exported_program = convert_and_export_with_cache(model) - ep_generated_ids = TorchExportableModuleWithStaticCache.generate( - exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens - ) - ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) - - -@slow -@require_torch_accelerator -class Mask4DTestHard(unittest.TestCase): - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - def setUp(self): - cleanup(torch_device, gc_collect=True) - model_name = "TinyOpenai/TinyOpenai-1.1B-Chat-v1.0" - self.model_dtype = torch.float32 - self.tokenizer = LlamaTokenizer.from_pretrained(model_name) - self.model = OpenaiForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_separate = [template.format(x) for x in items] # 3 separate lines - batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated - - input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) - input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) - - mask_shared_prefix = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, - ) - - position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) - - # building custom positions ids based on custom mask - position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) - # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - - # inverting the mask - min_dtype = torch.finfo(self.model_dtype).min - mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype - - return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix - - def test_stacked_causal_mask(self): - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - outs_1b = self.model.forward( - input_1b, - attention_mask=mask_1b, - position_ids=position_ids_1b, - past_key_values=past_key_values_a, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) - - def test_stacked_causal_mask_static_cache(self): - """same as above but with StaticCache""" - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - padded_attention_mask = torch.nn.functional.pad( - input=mask_shared_prefix, - pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, - attention_mask=padded_attention_mask, - position_ids=position_ids_shared_prefix, - cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), - past_key_values=past_key_values, - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask_static_cache(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - # forward run for the first part of input - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - padded_mask_1a = torch.nn.functional.pad( - input=mask_1a, - pad=(0, max_cache_len - mask_1a.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - _ = self.model.forward( - input_1a, - attention_mask=padded_mask_1a, - position_ids=position_ids_1a, - cache_position=torch.arange(part_a, device=torch_device), - past_key_values=past_key_values, - ) - - # forward run for the second part of input - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - - padded_mask_1b = torch.nn.functional.pad( - input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 - ) - - outs_1b = self.model.forward( - input_1b, - attention_mask=padded_mask_1b, - position_ids=position_ids_1b, - cache_position=torch.arange( - part_a, - input_ids_shared_prefix.shape[-1], - device=torch_device, - ), - past_key_values=past_key_values, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) + def test_lm_generate_openai_gpt(self): + model = OpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt") + model.to(torch_device) + input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is + expected_output_ids = [ + 481, + 4735, + 544, + 246, + 963, + 870, + 762, + 239, + 244, + 40477, + 244, + 249, + 719, + 881, + 487, + 544, + 240, + 244, + 603, + 481, + ] # the president is a very good man. " \n " i\'m sure he is, " said the + + output_ids = model.generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/models/openai_moe/__init__.py b/tests/models/openai_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py new file mode 100644 index 000000000000..4239e0816bf7 --- /dev/null +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -0,0 +1,932 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch openai model.""" + +import unittest + +from packaging import version +from parameterized import parameterized + +from transformers import AutoTokenizer, OpenaiConfig, StaticCache, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + Expectations, + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + LlamaTokenizer, + OpenaiForCausalLM, + OpenaiForQuestionAnswering, + OpenaiForSequenceClassification, + OpenaiForTokenClassification, + OpenaiModel, + ) + from transformers.models.openai.modeling_openai import OpenaiRotaryEmbedding + + +class OpenaiModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return OpenaiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = OpenaiModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class OpenaiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + OpenaiModel, + OpenaiForCausalLM, + OpenaiForSequenceClassification, + OpenaiForQuestionAnswering, + OpenaiForTokenClassification, + ) + if is_torch_available() + else () + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = OpenaiForCausalLM if is_torch_available() else None + + def setUp(self): + self.model_tester = OpenaiModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenaiConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_openai_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = OpenaiForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_openai_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = OpenaiForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = OpenaiModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = OpenaiModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn( + 1, dtype=torch.float32, device=torch_device + ) # used exclusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) + + # Sanity check original RoPE + original_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + + # Sanity check Yarn RoPE scaling + # Scaling should be over the entire input + config.rope_scaling = {"type": "yarn", "factor": scaling_factor} + yarn_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + def test_model_loading_old_rope_configs(self): + def _reinitialize_config(base_config, new_kwargs): + # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation + # steps. + base_config_dict = base_config.to_dict() + new_config = OpenaiConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) + return new_config + + # from untouched config -> ✅ + base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() + original_model = OpenaiForCausalLM(base_config).to(torch_device) + original_model(**model_inputs) + + # from a config with the expected rope configuration -> ✅ + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC + config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) + config = _reinitialize_config( + base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} + ) + self.assertTrue(config.rope_scaling["type"] == "linear") + self.assertTrue(config.rope_scaling["rope_type"] == "linear") + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("factor field", logs.output[0]) + + # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config( + base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} + ) + original_model = OpenaiForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("Unrecognized keys", logs.output[0]) + + # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception + with self.assertRaises(KeyError): + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" + + +@require_torch_accelerator +class OpenaiIntegrationTest(unittest.TestCase): + def tearDown(self): + # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves + # some memory allocated in the cache, which means some object is not being released properly. This causes some + # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. + # Investigate the root cause. + cleanup(torch_device, gc_collect=False) + + @slow + @require_read_token + def test_openai_3_1_hard(self): + """ + An integration test for openai 3.1. It tests against a long output to ensure the subtle numerical differences + from openai 3.1.'s RoPE can be detected + """ + # diff on `EXPECTED_TEXT`: + # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. + EXPECTED_TEXT = ( + "Tell me about the french revolution. The french revolution was a period of radical political and social " + "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " + "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " + "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " + "assembly that had not met since 1614. The Third Estate, which represented the common people, " + "demanded greater representation and eventually broke away to form the National Assembly. This marked " + "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" + ) + + tokenizer = AutoTokenizer.from_pretrained("meta-openai/Meta-Openai-3.1-8B-Instruct") + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Meta-Openai-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 + ) + input_text = ["Tell me about the french revolution."] + model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(generated_text, EXPECTED_TEXT) + + @slow + @require_read_token + def test_model_7b_logits_bf16(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + # Expected mean on dim = -1 + + # fmt: off + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), + ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), + ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) + }) + # fmt: on + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + @require_read_token + def test_model_7b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + + # fmt: off + # Expected mean on dim = -1 + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), + ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), + ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), + ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) + }) + # fmt: on + + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_torch_accelerator + @require_read_token + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = LlamaTokenizer.from_pretrained( + "meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right" + ) + model = OpenaiForCausalLM.from_pretrained( + "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + @slow + @require_read_token + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + openai_models = { + "meta-openai/Openai-3.2-1B": [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all " + "observers, regardless of their location, and 2) the laws of physics are the same for all observers" + ], + } + + for openai_model_ckp, EXPECTED_TEXT_COMPLETION in openai_models.items(): + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(openai_model_ckp, pad_token="", padding_side="right") + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = OpenaiForCausalLM.from_pretrained( + openai_model_ckp, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": device, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + model_name = "TinyOpenai/TinyOpenai-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = LlamaTokenizer.from_pretrained(model_name) + self.model = OpenaiForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, + attention_mask=mask_1b, + position_ids=position_ids_1b, + past_key_values=past_key_values_a, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + def test_stacked_causal_mask_static_cache(self): + """same as above but with StaticCache""" + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + padded_attention_mask = torch.nn.functional.pad( + input=mask_shared_prefix, + pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, + attention_mask=padded_attention_mask, + position_ids=position_ids_shared_prefix, + cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), + past_key_values=past_key_values, + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask_static_cache(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + # forward run for the first part of input + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + padded_mask_1a = torch.nn.functional.pad( + input=mask_1a, + pad=(0, max_cache_len - mask_1a.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + _ = self.model.forward( + input_1a, + attention_mask=padded_mask_1a, + position_ids=position_ids_1a, + cache_position=torch.arange(part_a, device=torch_device), + past_key_values=past_key_values, + ) + + # forward run for the second part of input + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + + padded_mask_1b = torch.nn.functional.pad( + input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 + ) + + outs_1b = self.model.forward( + input_1b, + attention_mask=padded_mask_1b, + position_ids=position_ids_1b, + cache_position=torch.arange( + part_a, + input_ids_shared_prefix.shape[-1], + device=torch_device, + ), + past_key_values=past_key_values, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) From 6749ebc660cebd6803dffeb0372baf00ea5d1003 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:46:48 +0200 Subject: [PATCH 061/342] simplify tests + style --- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- .../models/auto/tokenization_auto.py | 8 +- .../openai_moe/configuration_openai_moe.py | 5 +- .../convert_openai_weights_to_hf.py | 20 +- .../openai_moe/test_modeling_openai_moe.py | 582 +----------------- 6 files changed, 44 insertions(+), 579 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b509f3ccdf84..d653d6e0076f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -242,8 +242,8 @@ ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), - ("openai_moe", "OpenAIMoeConfig"), ("openai-gpt", "OpenAIGPTConfig"), + ("openai_moe", "OpenAIMoeConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), @@ -621,8 +621,8 @@ ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), - ("openai_moe", "OpenAIMoe"), ("openai-gpt", "OpenAI GPT"), + ("openai_moe", "OpenAIMoe"), ("opt", "OPT"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4925fc785cf3..0d3f64032065 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -232,8 +232,8 @@ ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), - ("openai_moe", "OpenAIMoeModel"), ("openai-gpt", "OpenAIGPTModel"), + ("openai_moe", "OpenAIMoeModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), ("owlvit", "OwlViTModel"), @@ -607,8 +607,8 @@ ("olmo2", "Olmo2ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), - ("openai_moe", "OpenAIMoeForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("openai_moe", "OpenAIMoeForCausalLM"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), ("persimmon", "PersimmonForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 6389f74d429e..ecbfaecdadfa 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -406,17 +406,11 @@ ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None), ), ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), - ( - "openai_moe", - ( - None, - "PreTrainedTokenizerFast" if is_tokenizers_available() else None, - ), - ), ( "openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None), ), + ("openai_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 9134d87905a1..8abe91d1e6db 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,6 +29,7 @@ class OpenAIMoeConfig(PretrainedConfig): [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. """ + model_type = "openai_moe" # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright @@ -111,8 +112,8 @@ def __init__( if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) - - self.attention_bias = True + + self.attention_bias = True self.mlp_bias = False self.max_position_embeddings = max_position_embeddings self.router_aux_loss_coef = router_aux_loss_coef diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 1a4c49d44bb7..7feab380e0cc 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -19,8 +19,8 @@ from typing import List, Optional import regex as re +import tiktoken import torch -from tqdm import tqdm from safetensors.torch import load_file as safe_load from transformers import ( @@ -56,7 +56,6 @@ # fmt: on - def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): """ This function should be applied only once, on the concatenated keys to efficiently rename using @@ -75,7 +74,6 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): return output_dict - def write_model( model_path, input_base_path, @@ -93,7 +91,7 @@ def write_model( final_ = {} for file in list(os.listdir(input_base_path)): if file.endswith(".safetensors"): - final_.update(safe_load(os.path.join(input_base_path,file)) ) + final_.update(safe_load(os.path.join(input_base_path, file))) print("Converting ..") all_keys = final_.keys() @@ -109,7 +107,11 @@ def write_model( if re.search("qkv_proj", new_key): q_len = config.head_dim * config.num_attention_heads k_len = config.head_dim * config.num_key_value_heads - q, k, v = final_[key][:q_len, ...], final_[key][q_len:k_len+q_len, ...], final_[key][k_len+q_len:, ...] + q, k, v = ( + final_[key][:q_len, ...], + final_[key][q_len : k_len + q_len, ...], + final_[key][k_len + q_len :, ...], + ) q_key = re.sub(r"qkv_proj", "q_proj", new_key) k_key = re.sub(r"qkv_proj", "k_proj", new_key) v_key = re.sub(r"qkv_proj", "v_proj", new_key) @@ -117,11 +119,11 @@ def write_model( state_dict[k_key] = k.contiguous().to(torch.bfloat16) state_dict[v_key] = v.contiguous().to(torch.bfloat16) elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key: - state_dict[new_key] = final_[key].permute(0,2,1).contiguous() # einsum in orignal, I use bmm + state_dict[new_key] = final_[key].permute(0, 2, 1).contiguous() # einsum in orignal, I use bmm else: weight = final_[key] if not re.search("norm", new_key): - weight = weight.to(torch.bfloat16) # norms are the only ones in float32 + weight = weight.to(torch.bfloat16) # norms are the only ones in float32 state_dict[new_key] = weight del final_ @@ -182,7 +184,7 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) -import tiktoken + class OpenAIMoeConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): tokenizer = tiktoken.get_encoding(tiktoken_url) @@ -224,7 +226,7 @@ def __init__( self.additional_special_tokens = {} # 199998 is not defined either self.additional_special_tokens["<|reserved_199998|>"] = 199998 - self.additional_special_tokens = {'<|endoftext|>': 199999, '<|endofprompt|>': 200018} + self.additional_special_tokens = {"<|endoftext|>": 199999, "<|endofprompt|>": 200018} for k in range(199999, 200018): self.additional_special_tokens[f"<|reserved_{k}|>"] = k sorted_list = sorted(self.additional_special_tokens.items(), key=lambda x: x[1]) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 4239e0816bf7..243d626fbd70 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -15,11 +15,16 @@ import unittest +import torch from packaging import version -from parameterized import parameterized -from transformers import AutoTokenizer, OpenaiConfig, StaticCache, is_torch_available, set_seed -from transformers.generation.configuration_utils import GenerationConfig +from transformers import ( + AutoTokenizer, + OpenAIMoeConfig, + OpenAIMoeForCausalLM, + OpenAIMoeModel, + is_torch_available, +) from transformers.testing_utils import ( Expectations, cleanup, @@ -36,21 +41,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin -if is_torch_available(): - import torch - - from transformers import ( - LlamaTokenizer, - OpenaiForCausalLM, - OpenaiForQuestionAnswering, - OpenaiForSequenceClassification, - OpenaiForTokenClassification, - OpenaiModel, - ) - from transformers.models.openai.modeling_openai import OpenaiRotaryEmbedding - - -class OpenaiModelTester: +class OpenAIMoeModelTester: def __init__( self, parent, @@ -125,7 +116,7 @@ def prepare_config_and_inputs(self): return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels def get_config(self): - return OpenaiConfig( + return OpenAIMoeConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, @@ -144,7 +135,7 @@ def get_config(self): def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): - model = OpenaiModel(config=config) + model = OpenAIMoeModel(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask) @@ -167,14 +158,11 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class OpenaiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class OpenAIMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( ( - OpenaiModel, - OpenaiForCausalLM, - OpenaiForSequenceClassification, - OpenaiForQuestionAnswering, - OpenaiForTokenClassification, + OpenAIMoeModel, + OpenAIMoeForCausalLM, ) if is_torch_available() else () @@ -188,11 +176,11 @@ class OpenaiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix model_split_percents = [0.5, 0.7, 0.8] # used in `test_torch_compile_for_training` - _torch_compile_train_cls = OpenaiForCausalLM if is_torch_available() else None + _torch_compile_train_cls = OpenAIMoeForCausalLM if is_torch_available() else None def setUp(self): - self.model_tester = OpenaiModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenaiConfig, hidden_size=37) + self.model_tester = OpenAIMoeModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenAIMoeConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ -207,216 +195,9 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - def test_openai_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = OpenaiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_openai_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = OpenaiForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = OpenaiModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = OpenaiModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - - # Sanity check Yarn RoPE scaling - # Scaling should be over the entire input - config.rope_scaling = {"type": "yarn", "factor": scaling_factor} - yarn_scaling_rope = OpenaiRotaryEmbedding(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_short, original_cos_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - def test_model_loading_old_rope_configs(self): - def _reinitialize_config(base_config, new_kwargs): - # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation - # steps. - base_config_dict = base_config.to_dict() - new_config = OpenaiConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) - return new_config - - # from untouched config -> ✅ - base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() - original_model = OpenaiForCausalLM(base_config).to(torch_device) - original_model(**model_inputs) - - # from a config with the expected rope configuration -> ✅ - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC - config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) - config = _reinitialize_config( - base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} - ) - self.assertTrue(config.rope_scaling["type"] == "linear") - self.assertTrue(config.rope_scaling["rope_type"] == "linear") - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("factor field", logs.output[0]) - - # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config( - base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} - ) - original_model = OpenaiForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("Unrecognized keys", logs.output[0]) - - # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception - with self.assertRaises(KeyError): - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - @require_torch_accelerator -class OpenaiIntegrationTest(unittest.TestCase): +class OpenAIMoeIntegrationTest(unittest.TestCase): def tearDown(self): # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves # some memory allocated in the cache, which means some object is not being released properly. This causes some @@ -444,7 +225,7 @@ def test_openai_3_1_hard(self): ) tokenizer = AutoTokenizer.from_pretrained("meta-openai/Meta-Openai-3.1-8B-Instruct") - model = OpenaiForCausalLM.from_pretrained( + model = OpenAIMoeForCausalLM.from_pretrained( "meta-openai/Meta-Openai-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 ) input_text = ["Tell me about the french revolution."] @@ -459,7 +240,7 @@ def test_openai_3_1_hard(self): def test_model_7b_logits_bf16(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = OpenaiForCausalLM.from_pretrained( + model = OpenAIMoeForCausalLM.from_pretrained( "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" ) @@ -508,7 +289,7 @@ def test_model_7b_logits_bf16(self): def test_model_7b_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = OpenaiForCausalLM.from_pretrained( + model = OpenAIMoeForCausalLM.from_pretrained( "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.float16 ) @@ -563,8 +344,8 @@ def test_model_7b_dola_generation(self): "understanding of space and time." ) prompt = "Simply put, the theory of relativity states that " - tokenizer = LlamaTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") - model = OpenaiForCausalLM.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") + model = OpenAIMoeForCausalLM.from_pretrained( "meta-openai/Openai-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 ) model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) @@ -600,10 +381,8 @@ def test_compile_static_cache(self): "Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup.", ] - tokenizer = LlamaTokenizer.from_pretrained( - "meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right" - ) - model = OpenaiForCausalLM.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained("meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right") + model = OpenAIMoeForCausalLM.from_pretrained( "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) @@ -619,314 +398,3 @@ def test_compile_static_cache(self): ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - @slow - @require_read_token - def test_export_static_cache(self): - if version.parse(torch.__version__) < version.parse("2.4.0"): - self.skipTest(reason="This test requires torch >= 2.4 to run.") - - from transformers.integrations.executorch import ( - TorchExportableModuleWithStaticCache, - convert_and_export_with_cache, - ) - - openai_models = { - "meta-openai/Openai-3.2-1B": [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all " - "observers, regardless of their location, and 2) the laws of physics are the same for all observers" - ], - } - - for openai_model_ckp, EXPECTED_TEXT_COMPLETION in openai_models.items(): - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(openai_model_ckp, pad_token="", padding_side="right") - max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ - "input_ids" - ].shape[-1] - - # Load model - device = "cpu" - dtype = torch.bfloat16 - cache_implementation = "static" - attn_implementation = "sdpa" - batch_size = 1 - model = OpenaiForCausalLM.from_pretrained( - openai_model_ckp, - device_map=device, - torch_dtype=dtype, - attn_implementation=attn_implementation, - generation_config=GenerationConfig( - use_cache=True, - cache_implementation=cache_implementation, - max_length=max_generation_length, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_generation_length, - "device": device, - }, - ), - ) - - prompts = ["Simply put, the theory of relativity states that "] - prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - prompt_token_ids = prompt_tokens["input_ids"] - max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] - - # Static Cache + export - exported_program = convert_and_export_with_cache(model) - ep_generated_ids = TorchExportableModuleWithStaticCache.generate( - exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens - ) - ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) - - -@slow -@require_torch_accelerator -class Mask4DTestHard(unittest.TestCase): - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - def setUp(self): - cleanup(torch_device, gc_collect=True) - model_name = "TinyOpenai/TinyOpenai-1.1B-Chat-v1.0" - self.model_dtype = torch.float32 - self.tokenizer = LlamaTokenizer.from_pretrained(model_name) - self.model = OpenaiForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_separate = [template.format(x) for x in items] # 3 separate lines - batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated - - input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) - input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) - - mask_shared_prefix = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, - ) - - position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) - - # building custom positions ids based on custom mask - position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) - # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - - # inverting the mask - min_dtype = torch.finfo(self.model_dtype).min - mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype - - return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix - - def test_stacked_causal_mask(self): - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - outs_1b = self.model.forward( - input_1b, - attention_mask=mask_1b, - position_ids=position_ids_1b, - past_key_values=past_key_values_a, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) - - def test_stacked_causal_mask_static_cache(self): - """same as above but with StaticCache""" - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - padded_attention_mask = torch.nn.functional.pad( - input=mask_shared_prefix, - pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, - attention_mask=padded_attention_mask, - position_ids=position_ids_shared_prefix, - cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), - past_key_values=past_key_values, - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask_static_cache(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - # forward run for the first part of input - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - padded_mask_1a = torch.nn.functional.pad( - input=mask_1a, - pad=(0, max_cache_len - mask_1a.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - _ = self.model.forward( - input_1a, - attention_mask=padded_mask_1a, - position_ids=position_ids_1a, - cache_position=torch.arange(part_a, device=torch_device), - past_key_values=past_key_values, - ) - - # forward run for the second part of input - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - - padded_mask_1b = torch.nn.functional.pad( - input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 - ) - - outs_1b = self.model.forward( - input_1b, - attention_mask=padded_mask_1b, - position_ids=position_ids_1b, - cache_position=torch.arange( - part_a, - input_ids_shared_prefix.shape[-1], - device=torch_device, - ), - past_key_values=past_key_values, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) From ac5b5d4168d2355310f38555936367e93283819f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:50:00 +0200 Subject: [PATCH 062/342] revert flex change --- .../integrations/flex_attention.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index e69f7a4c6672..1e1228873f17 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -233,7 +233,6 @@ def flex_attention_forward( scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, - score_mod: Optional[callable] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: if head_mask is not None: @@ -257,15 +256,14 @@ def flex_attention_forward( if score_mask is not None: score_mask = score_mask[:, :, :, : key.shape[-2]] - if score_mod is not None: - def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): - if softcap is not None: - score = softcap * torch.tanh(score / softcap) - if score_mask is not None: - score = score + score_mask[batch_idx][0][q_idx][kv_idx] - if head_mask is not None: - score = score + head_mask[batch_idx][head_idx][0][0] - return score + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) + if score_mask is not None: + score = score + score_mask[batch_idx][0][q_idx][kv_idx] + if head_mask is not None: + score = score + head_mask[batch_idx][head_idx][0][0] + return score enable_gqa = True num_local_query_heads = query.shape[1] @@ -292,7 +290,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): training=module.training, ) # lse is returned in float32 - attention_weights = attention_weights.to(value.dtype)[:, :, :, : key.shape[-2]] # potential sink + attention_weights = attention_weights.to(value.dtype) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attention_weights From f929f841aa7c10bb981de76b7b8575f2f8749ae3 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:56:13 +0200 Subject: [PATCH 063/342] revert and correct change to masking utils --- src/transformers/masking_utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 5c402dd65d72..f5affab2306f 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -590,6 +590,7 @@ def _preprocess_mask_arguments( attention_mask: Optional[Union[torch.Tensor, BlockMask]], cache_position: torch.Tensor, past_key_values: Optional[Cache], + layer_idx: Optional[int], ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]: """ Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the @@ -640,10 +641,6 @@ def _preprocess_mask_arguments( # If using a cache, it can give all informations about mask sizes based on seen tokens if past_key_values is not None: - if hasattr(past_key_values, "is_sliding") and isinstance(past_key_values.is_sliding, list): - layer_idx = past_key_values.is_sliding.index(False) - else: - layer_idx = 0 kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx) # Otherwise, the sizes are simply the input sizes else: @@ -687,8 +684,13 @@ def create_causal_mask( useful to easily overlay another mask on top of the causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the full layers + if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(False) + else: + layer_idx = 0 + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx ) if early_exit: return attention_mask @@ -727,7 +729,7 @@ def create_causal_mask( ) return causal_mask -@torch.compiler.disable(recursive=True) + def create_sliding_window_causal_mask( config: PretrainedConfig, input_embeds: torch.Tensor, @@ -764,10 +766,13 @@ def create_sliding_window_causal_mask( useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - + if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: + layer_idx = past_key_values.is_sliding.index(True) + else: + layer_idx = 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx ) if early_exit: return attention_mask @@ -848,9 +853,9 @@ def create_chunked_causal_mask( useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. """ # If we have an HybridCache structure, here we want to create the mask for the sliding layers - try: + if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding: layer_idx = past_key_values.is_sliding.index(True) - except (ValueError, AttributeError): + else: layer_idx = 0 early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( From eb8a4cd5ba672ad526402c64dec9116471538de6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:57:44 +0200 Subject: [PATCH 064/342] revert change to example --- examples/3D_parallel.py | 68 ++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/examples/3D_parallel.py b/examples/3D_parallel.py index 958336eefd3e..fc6e3bd71062 100644 --- a/examples/3D_parallel.py +++ b/examples/3D_parallel.py @@ -81,7 +81,7 @@ def main(): seq_len = 1024 # Sequence length num_train_steps = 10000 # Number of training steps LR = 1e-5 - model_name = "/fsx/arthur/oai_hf" # Path to the model directory or Hugging Face model name + model_name = "HuggingFaceTB/SmolLM2-1.7B" # model_name = "unsloth/Llama-3.2-1B" CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}" @@ -109,27 +109,27 @@ def main(): f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}" ) - # if dist.get_rank() == 0: - # wandb.init( - # project="tp_dp_test", - # config={ - # "tp_size": tp_size, - # "dp_size": dp_size, - # "cp_size": cp_size, - # "global_batch_size": global_batch_size, - # "model_name": model_name, - # "dataset": "roneneldan/TinyStories-1M", - # "seq_len": seq_len, - # "lr": LR, - # "weight_decay": 0.1, - # }, - # name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" - # if model_name == "unsloth/Llama-3.2-1B" - # else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", - # ) - # logger.info("Wandb initialized.") - # # Log the current file to wandb - # wandb.save("test_train.py") + if dist.get_rank() == 0: + wandb.init( + project="tp_dp_test", + config={ + "tp_size": tp_size, + "dp_size": dp_size, + "cp_size": cp_size, + "global_batch_size": global_batch_size, + "model_name": model_name, + "dataset": "roneneldan/TinyStories-1M", + "seq_len": seq_len, + "lr": LR, + "weight_decay": 0.1, + }, + name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" + if model_name == "unsloth/Llama-3.2-1B" + else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", + ) + logger.info("Wandb initialized.") + # Log the current file to wandb + wandb.save("test_train.py") # Load model and tokenizer logger.info(f"Loading model and tokenizer from {model_name}") @@ -137,7 +137,7 @@ def main(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") - print(f"TP MESH: {tp_mesh}") + model = AutoModelForCausalLM.from_pretrained( model_name, device_mesh=tp_mesh if dist.is_initialized() else None, @@ -200,7 +200,7 @@ def create_packed_sequences(examples): batched=True, remove_columns=tokenized_dataset.column_names, batch_size=1000, # Process in batches for efficiency - # num_proc=60, + num_proc=60, ) logger.info(f"Dataset packed. New size: {len(packed_dataset)}") @@ -316,15 +316,15 @@ def collate_fn(batch): logger.info( f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}" ) - # wandb.log( - # { - # "train/loss": current_loss, - # "train/gradnorm": gradnorm, - # "step": step, - # "lr": LR, - # "GBS": global_batch_size, - # } - # ) + wandb.log( + { + "train/loss": current_loss, + "train/gradnorm": gradnorm, + "step": step, + "lr": LR, + "GBS": global_batch_size, + } + ) step += 1 # Increment step count @@ -349,7 +349,7 @@ def collate_fn(batch): logger.info("Cleaned up distributed process group") # Finish wandb run on rank 0 if dist.get_rank() == 0: - # wandb.finish() + wandb.finish() logger.info("Wandb run finished.") From 128db147387adedeccd3626d5f3137f280966ba6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 11:58:03 +0200 Subject: [PATCH 065/342] style --- src/transformers/integrations/tensor_parallel.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index f2116179396c..0e093d3ed409 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -414,7 +414,6 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # TODO: figure out dynamo support for instance method and switch this to instance method return outputs - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): param = param[...].to(param_casting_dtype) if to_contiguous: @@ -422,7 +421,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param / device_mesh.size() # TODO should be optionable return param - def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: distribute_module( module, @@ -807,7 +805,9 @@ def replace_state_dict_local_with_dtensor( return state_dict -def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None): +def add_tensor_parallel_hooks_to_module( + model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None +): """ Add hooks to the module holding the layer. Meaning: ``` @@ -878,7 +878,9 @@ def shard_and_distribute_module( # Add hooks to the module if not done yet # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) if not getattr(module_to_tp, "_is_hooked", False): - add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_shard_plan, device_mesh, parameter_name) + add_tensor_parallel_hooks_to_module( + model, module_to_tp, tp_plan, param_name, current_shard_plan, device_mesh, parameter_name + ) module_to_tp._is_hooked = True if current_shard_plan is not None: From 1e4242c3675247851efe0175d93b53c8df88cad6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 12:08:26 +0200 Subject: [PATCH 066/342] Update test_modeling_openai_moe.py --- tests/models/openai_moe/test_modeling_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 243d626fbd70..9adc509878d8 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -196,6 +196,7 @@ def test_model_various_embeddings(self): self.model_tester.create_and_check_model(*config_and_inputs) +@unittest.skip(reason="No available checkpoint for integration tests yet") @require_torch_accelerator class OpenAIMoeIntegrationTest(unittest.TestCase): def tearDown(self): From 0d0532e58455073356deb121088fdb943d511e26 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 12:12:50 +0200 Subject: [PATCH 067/342] licenses --- .../models/openai_moe/configuration_openai_moe.py | 7 +------ .../models/openai_moe/convert_openai_weights_to_hf.py | 2 +- src/transformers/models/openai_moe/modeling_openai_moe.py | 7 +------ src/transformers/models/openai_moe/modular_openai_moe.py | 7 +------ tests/models/openai_moe/test_modeling_openai_moe.py | 2 +- 5 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 8abe91d1e6db..3467a9522d8c 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -1,10 +1,5 @@ # coding=utf-8 -# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 7feab380e0cc..5e673b6e6453 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index fc5eb34f0c9f..5db22f4ffa22 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -5,12 +5,7 @@ # modular_openai_moe.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index abab709e423a..0096c4be9e2e 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -1,10 +1,5 @@ # coding=utf-8 -# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 9adc509878d8..8310505ddd6a 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ce8b97d3a5a23b4ff0b5c7747ed3a8268b15af4a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 12:28:09 +0200 Subject: [PATCH 068/342] model init --- .../models/openai_moe/modeling_openai_moe.py | 7 ++ .../models/openai_moe/modular_openai_moe.py | 82 +++++++++++-------- 2 files changed, 57 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 5db22f4ffa22..42f6d621183a 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -396,6 +396,13 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() elif isinstance(module, OpenAIMoeRMSNorm): module.weight.data.fill_(1.0) + elif isinstance(module, OpenAIMoeExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, OpenAIMoeAttention): + module.sinks.data.normal_(mean=0.0, std=std) @auto_docstring diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 0096c4be9e2e..983e6d717efc 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -205,38 +205,37 @@ def eager_attention_forward( return attn_output, attn_weights -def openai_flex_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2], -1) - - def attention_sink(score, b, h, q_idx, kv_idx): - score = torch.cat([score, sinks], dim=-1) - return score - - # TODO I need to remove the -1 sinks - return flex_attention_forward( - module, - query, - key, - value, - attention_mask, - scaling=scaling, - dropout=dropout, - attention_sink=attention_sink, - score_mod=attention_sink, - **kwargs, - ) - - -ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) +# def openai_flex_attention_forward( +# module: nn.Module, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# attention_mask: Optional[torch.Tensor], +# scaling: float, +# dropout: float = 0.0, +# **kwargs, +# ): +# sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2], -1) + +# def attention_sink(score, b, h, q_idx, kv_idx): +# score = torch.cat([score, sinks], dim=-1) +# return score + +# # TODO I need to remove the -1 sinks +# return flex_attention_forward( +# module, +# query, +# key, +# value, +# attention_mask, +# scaling=scaling, +# dropout=dropout, +# attention_sink=attention_sink, +# score_mod=attention_sink, +# **kwargs, +# ) + +# ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) class OpenAIMoeAttention(Qwen2Attention): @@ -313,6 +312,25 @@ def forward( class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, OpenAIMoeRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, OpenAIMoeExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, OpenAIMoeAttention): + module.sinks.data.normal_(mean=0.0, std=std) class OpenAIMoeModel(MixtralModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] From fde43e02b88fb4896687bd8e33fda400cb10c29b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 12:33:08 +0200 Subject: [PATCH 069/342] doc --- docs/source/en/model_doc/openai_moe.md | 33 +++++-------------- .../models/openai_moe/modular_openai_moe.py | 3 +- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/docs/source/en/model_doc/openai_moe.md b/docs/source/en/model_doc/openai_moe.md index 211d954de749..2c0b39013dc4 100644 --- a/docs/source/en/model_doc/openai_moe.md +++ b/docs/source/en/model_doc/openai_moe.md @@ -24,13 +24,11 @@ rendered properly in your Markdown viewer. -# openai - -# openai +# OpenAIMoE ## Overview -The openai model was proposed in []() by . +The OpenAIMoE model was proposed in []() by . The abstract from the paper is the following: @@ -45,31 +43,16 @@ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface The original code can be found [here](). -## OpenaiConfig - -[[autodoc]] OpenaiConfig - -## OpenaiModel - -[[autodoc]] OpenaiModel - - forward - -## OpenaiForCausalLM - -[[autodoc]] OpenaiForCausalLM - - forward +## OpenAIMoeConfig -## OpenaiForSequenceClassification - -[[autodoc]] OpenaiForSequenceClassification - - forward +[[autodoc]] OpenAIMoeConfig -## OpenaiForQuestionAnswering +## OpenAIMoeModel -[[autodoc]] OpenaiForQuestionAnswering +[[autodoc]] OpenAIMoeModel - forward -## OpenaiForTokenClassification +## OpenAIMoeForCausalLM -[[autodoc]] OpenaiForTokenClassification +[[autodoc]] OpenAIMoeForCausalLM - forward diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 983e6d717efc..56d4bb7cc61a 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -18,14 +18,12 @@ from torch import nn from ...cache_utils import Cache, DynamicCache -from ...integrations.flex_attention import flex_attention_forward from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoeModelOutputWithPast, ) from ...modeling_rope_utils import dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging from ..llama.modeling_llama import ( @@ -332,6 +330,7 @@ def _init_weights(self, module): elif isinstance(module, OpenAIMoeAttention): module.sinks.data.normal_(mean=0.0, std=std) + class OpenAIMoeModel(MixtralModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] From 7ebcfdb98a42b0bf2b14abe3a9f48f3db515feef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 11 Jun 2025 14:38:06 +0200 Subject: [PATCH 070/342] Update convert_openai_weights_to_hf.py --- .../models/openai_moe/convert_openai_weights_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 5e673b6e6453..cfe1b3b63eaf 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -129,14 +129,14 @@ def write_model( del final_ gc.collect() - print("Loading the checkpoint in a OpenAIMoe ") + print("Loading the checkpoint in a OpenAIMoe model") with torch.device("meta"): model = OpenAIMoeForCausalLM(config) model.load_state_dict(state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") del config._name_or_path - print("Saving the ") + print("Saving the model") model.save_pretrained(model_path, safe_serialization=safe_serialization) del state_dict, model From e4897635fae62ae9d437030aad82889f71a143c4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 26 Jun 2025 15:36:30 +0000 Subject: [PATCH 071/342] update with attention fix --- .../models/openai_moe/modular_openai_moe.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 56d4bb7cc61a..511e622196ef 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -173,7 +173,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = _apply_rotary_emb(k, cos, sin) return q_embed, k_embed - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -195,14 +194,18 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + # scale the logits to prevent overflows + logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + sinks = torch.exp(sinks - logits_max) + unnormalized_scores = torch.exp(attn_weights - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights - # def openai_flex_attention_forward( # module: nn.Module, # query: torch.Tensor, From f626f79696c2b12a868e9f3669eebcf24a449374 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 26 Jun 2025 17:36:45 +0200 Subject: [PATCH 072/342] Update conversion script for the 20b --- .../openai_moe/convert_openai_weights_to_hf.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index cfe1b3b63eaf..56c222ee342d 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -16,6 +16,7 @@ import gc import json import os +from pathlib import Path from typing import List, Optional import regex as re @@ -85,7 +86,18 @@ def write_model( eos_token_id = 199999 if not instruct else [199999, 200018] pad_token_id = 128004 - config = OpenAIMoeConfig() + original_config = json.loads((Path(input_base_path) / "config.json").read_text()) + + num_local_experts = original_config.pop("num_experts") + rope_scaling = { + "beta_fast": float(original_config.pop("rope_ntk_beta")), + "beta_slow": float(original_config.pop("rope_ntk_alpha")), + "factor": float(original_config.pop('rope_scaling_factor')), + "rope_type": "yarn", + "truncate": False + } + + config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, **original_config) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} From 37741c146262ef3d5599949195129f82bda9d36d Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 26 Jun 2025 17:49:44 +0200 Subject: [PATCH 073/342] original_max_position_embeddings --- .../models/openai_moe/convert_openai_weights_to_hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 56c222ee342d..3312ae58f71a 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -94,7 +94,8 @@ def write_model( "beta_slow": float(original_config.pop("rope_ntk_alpha")), "factor": float(original_config.pop('rope_scaling_factor')), "rope_type": "yarn", - "truncate": False + "truncate": False, + "original_max_position_embeddings": 4096 } config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, **original_config) From cbec81ad12d1a49704cad8e27ee1bda843af4de6 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 10:36:13 +0000 Subject: [PATCH 074/342] bias should only be added once --- src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 42f6d621183a..7b5b90d2b7aa 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -118,7 +118,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + # add bias only on TP=0 so that we avoid adding it for all TPs + if torch.distributed.get_rank() == 0: + next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(-1, self.hidden_size) return next_states From ae2bce12a8d16df5d7e8a2bb9f24d0eff156efe9 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 10:36:56 +0000 Subject: [PATCH 075/342] fix: GatherParallel allreduces first element --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 7b5b90d2b7aa..8671e828d9d7 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -123,7 +123,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if torch.distributed.get_rank() == 0: next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(-1, self.hidden_size) - return next_states + return next_states, None # bcz GatherParallel allreduces first element class OpenAIMoeMLP(nn.Module): From 08aa2e2057084bb284b22da40eb5fbbebefe7962 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 10:39:48 +0000 Subject: [PATCH 076/342] fix: GatherParallel allreduces first element --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 8671e828d9d7..3fe3e024bc81 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -143,7 +143,7 @@ def forward(self, hidden_states): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) - routed_out = self.experts(hidden_states, router_indices, router_top_value) + routed_out, _ = self.experts(hidden_states, router_indices, router_top_value) #TODO: router_indices isn't used inside this func if self.training: output_states = routed_out.view(batch_size, -1, self.hidden_dim) else: From 94dfb88977f01d4ac6abbcfafed8c526310e2814 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 10:40:12 +0000 Subject: [PATCH 077/342] raise if eager not used --- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 3fe3e024bc81..dee02ea49f00 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -299,6 +299,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": + raise ValueError(f"Attention implementation {self.config._attn_implementation} doesn't support sinks") attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( From 38fac7f4fda772ff92dcb56e068bda40d4cbdc64 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 10:44:15 +0000 Subject: [PATCH 078/342] add test script --- hf_generate.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 hf_generate.py diff --git a/hf_generate.py b/hf_generate.py new file mode 100644 index 000000000000..734808277e35 --- /dev/null +++ b/hf_generate.py @@ -0,0 +1,78 @@ +""" +torchrun --nproc_per_node=2 hf_generate.py +""" + +from transformers import AutoTokenizer, OpenAIMoeForCausalLM +import torch +import os +import logging +import torch.distributed as dist +from torch.distributed.tensor.experimental import context_parallel +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.distributed.device_mesh import DeviceMesh + +model_id = "ft-hf-o-c/random-checkpoint-converted-20b" + + +# torch.use_deterministic_algorithms(True) +torch.backends.cudnn.deterministic = True + +# Set up logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +def main(): + tp_size = int(os.environ.get("TP_SIZE", 4)) + + # Initialize distributed environment + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + + mesh = torch.arange(world_size).reshape(tp_size) + world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("tp",)) + tp_mesh = world_mesh["tp"] + logger.info(f"Created DeviceMesh: {world_mesh}") + logger.info( + f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, TP: {tp_mesh.get_local_rank()}" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # messages = [ + # {"role": "user", "content": "Who are you?"}, + # ] + # inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) + inputs = tokenizer("Hello! How are you?", return_tensors="pt") + model = OpenAIMoeForCausalLM.from_pretrained( + model_id, + device_mesh=tp_mesh if dist.is_initialized() else None, + tp_plan="auto", + tp_size=tp_size, + torch_dtype=torch.bfloat16, + # torch_dtype=torch.float32, + attn_implementation="eager", + ) + logger.info(f"Model loaded onto device mesh: {tp_mesh}") + device = torch.device(f"cuda:{local_rank}") + logger.info(f"Using device: {device} for non-model tensors") + model.eval() + + outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) + outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) + print(outputs[0]) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file From ca817e6839e45481d9cdbcb8cfe28a4ef169bff1 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 12:12:04 +0000 Subject: [PATCH 079/342] disable bias normalization for now --- src/transformers/integrations/tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 0e093d3ed409..91c41d62386d 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -418,7 +418,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param[...].to(param_casting_dtype) if to_contiguous: param = param.contiguous() - param = param / device_mesh.size() # TODO should be optionable + # param = param / device_mesh.size() # TODO should be optionable return param def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: From 0e4ace000ba2cb9b4b7e09a4625133359a01b6dc Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 30 Jun 2025 12:12:38 +0000 Subject: [PATCH 080/342] gate_up_proj_bias must be local_packed_rowwise --- .../models/openai_moe/configuration_openai_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 3467a9522d8c..a768586ff0ce 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -35,9 +35,9 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_bias": "local_rowwise", + "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", + "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs "layers.*.mlp.experts": "gather", } base_model_pp_plan = { From 679826ad5228dd4c4b27a47d40680cd2c16c4ece Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Tue, 1 Jul 2025 08:41:40 +0000 Subject: [PATCH 081/342] doing rowlinear allreduce before adding bias --- src/transformers/integrations/tensor_parallel.py | 4 ++-- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 91c41d62386d..e09b89819463 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -391,7 +391,7 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # this op cannot be async, otherwise it completely breaks the outputs of models - torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) + # torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: need to create a new class for backward compatibility return outputs @@ -877,7 +877,7 @@ def shard_and_distribute_module( # Add hooks to the module if not done yet # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) - if not getattr(module_to_tp, "_is_hooked", False): + if not getattr(module_to_tp, "_is_hooked", False): # this adds gather hook to layers.*.mlp.experts and skips the rest add_tensor_parallel_hooks_to_module( model, module_to_tp, tp_plan, param_name, current_shard_plan, device_mesh, parameter_name ) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index dee02ea49f00..abcd840684c7 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -119,6 +119,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) + torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook # add bias only on TP=0 so that we avoid adding it for all TPs if torch.distributed.get_rank() == 0: next_states = next_states + self.down_proj_bias[..., None, :] From de1b870a691291d6393581d467c4070491cb71c7 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Tue, 1 Jul 2025 08:44:57 +0000 Subject: [PATCH 082/342] testing --- src/transformers/generation/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 713d57a8994d..23838076f043 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3573,6 +3573,8 @@ def _sample( else: is_prefill = True + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("pytorch-os-mini-final-quantized-moe-sharded_hf") while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3631,6 +3633,9 @@ def _sample( else: next_tokens = torch.argmax(next_token_scores, dim=-1) + if torch.distributed.get_rank() == 0: + print(f"Generated token: {repr(tokenizer.decode(next_tokens))}, logprob: {next_token_logits.max()}") + # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) From 5b3eae205fa04b1451fec0476f9d1a30e56de56d Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Tue, 1 Jul 2025 08:47:28 +0000 Subject: [PATCH 083/342] dont parallelize attn for exact match of greedy gen --- .../models/openai_moe/configuration_openai_moe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index a768586ff0ce..b2ce502630d7 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,11 +29,11 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "local_rowwise", + # "layers.*.self_attn.q_proj": "colwise", + # "layers.*.self_attn.k_proj": "colwise", + # "layers.*.self_attn.v_proj": "colwise", + # "layers.*.self_attn.o_proj": "rowwise", + # "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", From b8b50a5c0b743759ffd2b4ee7b979b842913ecb7 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Tue, 1 Jul 2025 16:43:06 +0000 Subject: [PATCH 084/342] EP in tensor_parallel.py --- src/transformers/integrations/tensor_parallel.py | 15 +++++++++++++++ .../openai_moe/configuration_openai_moe.py | 16 +++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index e09b89819463..e791eb1c0397 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -736,6 +736,20 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) +class GroupedGemmParallel(TensorParallelLayer): + def __init__(self): + super().__init__() + self.use_dtensor = False + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + ep_rank = rank + global_num_experts = empty_param.shape[0] + assert global_num_experts % device_mesh.size() == 0, f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0" + local_num_experts = global_num_experts // device_mesh.size() + param = param[ep_rank*local_num_experts:(ep_rank+1)*local_num_experts].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + return param class ParallelInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if @@ -753,6 +767,7 @@ class ParallelInterface(GeneralInterface): "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), "sequence_parallel": SequenceParallel(), "replicate": ReplicateParallel(), + "grouped_gemm": GroupedGemmParallel(), } if is_torch_greater_or_equal("2.5") and _torch_distributed_available else {} diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index b2ce502630d7..7d8ae8feb74d 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -34,11 +34,17 @@ class OpenAIMoeConfig(PretrainedConfig): # "layers.*.self_attn.v_proj": "colwise", # "layers.*.self_attn.o_proj": "rowwise", # "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs - "layers.*.mlp.experts": "gather", + + # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + # "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + # "layers.*.mlp.experts.down_proj": "local_colwise", + # "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs + # "layers.*.mlp.experts": "gather", + + 'model.layers.*.mlp.experts.gate_up_proj': "grouped_gemm", + 'model.layers.*.mlp.experts.gate_up_proj_bias': "grouped_gemm", + 'model.layers.*.mlp.experts.down_proj': "grouped_gemm", + 'model.layers.*.mlp.experts.down_proj_bias': "grouped_gemm", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), From deb4edbd0551d869dbe9dc96116087c5b1bf5850 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 2 Jul 2025 10:15:17 +0000 Subject: [PATCH 085/342] Naive Expert Parallelism (greedy gen matches) --- hf_generate.py | 16 +++++--- src/transformers/configuration_utils.py | 5 +++ .../openai_moe/configuration_openai_moe.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 37 ++++++++++++------- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/hf_generate.py b/hf_generate.py index 734808277e35..2232908cd2cd 100644 --- a/hf_generate.py +++ b/hf_generate.py @@ -1,5 +1,5 @@ """ -torchrun --nproc_per_node=2 hf_generate.py +TP_SIZE=8 torchrun --nproc_per_node=8 hf_generate.py """ from transformers import AutoTokenizer, OpenAIMoeForCausalLM @@ -11,7 +11,11 @@ from torch.nn.attention import SDPBackend, sdpa_kernel from torch.distributed.device_mesh import DeviceMesh -model_id = "ft-hf-o-c/random-checkpoint-converted-20b" +# model_id = "ft-hf-o-c/random-checkpoint-converted-20b" +# openai = "/fsx/vb/pytorch-os-mini-final-quantized-moe-sharded" + +# model_id = "/scratch/pytorch-os-mini-final-quantized-moe-sharded_hf" +model_id = "pytorch-os-mini-final-quantized-moe-sharded_hf" # torch.use_deterministic_algorithms(True) @@ -27,7 +31,7 @@ def main(): - tp_size = int(os.environ.get("TP_SIZE", 4)) + tp_size = int(os.environ.get("TP_SIZE", 8)) # Initialize distributed environment if "RANK" in os.environ and "WORLD_SIZE" in os.environ: @@ -51,11 +55,13 @@ def main(): # {"role": "user", "content": "Who are you?"}, # ] # inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) - inputs = tokenizer("Hello! How are you?", return_tensors="pt") + # inputs = tokenizer("Hello! How are you?", return_tensors="pt") + inputs = tokenizer("Who are you? And who made you?", return_tensors="pt") model = OpenAIMoeForCausalLM.from_pretrained( model_id, device_mesh=tp_mesh if dist.is_initialized() else None, tp_plan="auto", + enable_expert_parallel=os.environ.get("ENABLE_EXPERT_PARALLEL", "0") == "1", tp_size=tp_size, torch_dtype=torch.bfloat16, # torch_dtype=torch.float32, @@ -66,7 +72,7 @@ def main(): logger.info(f"Using device: {device} for non-model tensors") model.eval() - outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) + outputs = model.generate(**inputs.to(model.device), max_new_tokens=100, use_cache=False, do_sample=False) outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) print(outputs[0]) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 205a7dde8f28..91f5f878de25 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -302,6 +302,9 @@ def __init__(self, **kwargs): self._attn_implementation_internal = kwargs.pop("attn_implementation", None) self._attn_implementation_autoset = False + # Expert parallelism + self.enable_expert_parallel = kwargs.pop("enable_expert_parallel", False) + # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -1017,6 +1020,8 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None: del d["_attn_implementation_internal"] if "_attn_implementation_autoset" in d: del d["_attn_implementation_autoset"] + if "enable_expert_parallel" in d: + del d["enable_expert_parallel"] # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in d: del d["base_model_tp_plan"] diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 7d8ae8feb74d..668cb897572f 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -55,7 +55,7 @@ class OpenAIMoeConfig(PretrainedConfig): def __init__( self, num_hidden_layers: int = 36, - num_local_experts: int = 128, + num_local_experts: int = 128, #TODO: rename to num_experts otherwise confusing with EP vocab_size: int = 201088, hidden_size: int = 2880, intermediate_size: int = 2880, diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index abcd840684c7..951d8cb7ed8e 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Tuple, Union +import torch.distributed as dist import torch from torch import nn @@ -64,7 +65,9 @@ def extra_repr(self): class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_local_experts + self.enable_expert_parallel = config.enable_expert_parallel + self.ep_size = dist.get_world_size() if (self.enable_expert_parallel and dist.is_initialized()) else 1 + self.num_experts = config.num_local_experts // self.ep_size self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size @@ -119,20 +122,22 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) - torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook - # add bias only on TP=0 so that we avoid adding it for all TPs - if torch.distributed.get_rank() == 0: - next_states = next_states + self.down_proj_bias[..., None, :] + if not self.enable_expert_parallel: + torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook + next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(-1, self.hidden_size) return next_states, None # bcz GatherParallel allreduces first element - class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok + # self.top_k = 1 self.hidden_dim = config.hidden_size - self.num_local_experts = config.num_local_experts + self.enable_expert_parallel = config.enable_expert_parallel + self.ep_size = dist.get_world_size() if (self.enable_expert_parallel and dist.is_initialized()) else 1 + self.ep_rank = dist.get_rank() if (self.enable_expert_parallel and dist.is_initialized()) else 0 + self.num_local_experts = config.num_local_experts // self.ep_size # TODO: bad naming self.experts = OpenAIMoeExperts(config) self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) @@ -140,17 +145,23 @@ def forward(self, hidden_states): # we don't slice weight as its not compile compatible batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_logits = self.router(hidden_states) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) - routed_out, _ = self.experts(hidden_states, router_indices, router_top_value) #TODO: router_indices isn't used inside this func + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + if self.enable_expert_parallel: + local_router_scores = router_scores[self.ep_rank * self.num_local_experts:(self.ep_rank + 1) * self.num_local_experts] + else: + local_router_scores = router_scores + routed_out, _ = self.experts(hidden_states, router_indices, local_router_scores) #TODO: router_indices isn't used inside this func if self.training: output_states = routed_out.view(batch_size, -1, self.hidden_dim) else: - routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] + routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * local_router_scores[..., None] # we're throwing away computed routed_out for rest of experts output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) - return output_states, router_scores + if self.enable_expert_parallel: + torch.distributed.all_reduce(output_states, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook + return output_states, local_router_scores class OpenAIMoeRotaryEmbedding(nn.Module): From a3b22ced6fa2972d5fd3ebb6bcd99c73175926d5 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 2 Jul 2025 13:46:11 +0000 Subject: [PATCH 086/342] feat: add megablocks moe mlp kernel --- src/transformers/integrations/hub_kernels.py | 44 +++---------------- .../models/openai_moe/modeling_openai_moe.py | 1 + 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 63e0c381e798..2993f953ec90 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -23,9 +23,7 @@ register_kernel_mapping, replace_kernel_forward_from_hub, ) - from kernels import ( - use_kernel_forward_from_hub as original_use_kernel_forward_from_hub, - ) + from kernels import use_kernel_forward_from_hub _hub_kernels_available = True @@ -56,44 +54,16 @@ layer_name="TritonLlamaMLP", ) }, + "MegaBlocksMoeMLP": { + "cuda": LayerRepository( + repo_id="kernels-community/megablocks", + layer_name="MegaBlocksMoeMLP", + ) + }, } register_kernel_mapping(_KERNEL_MAPPING) - def use_kernel_forward_from_hub(*args, **kwargs): - """ - Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed - when `kernels` supports `torch.compile`. - - If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the - kernel. - """ - - def decorator_with_compile_path(cls): - # Keeps a reference to the original forward method - original_forward = cls.forward - - # Applies the original decorator - decorator = original_use_kernel_forward_from_hub(*args, **kwargs) - cls = decorator(cls) - - # Replaces the kernel forward with a compile-friendly version - kernel_forward = cls.forward - - def forward_with_compile_path(*forward_args, **forward_kwargs): - disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None) - if is_torchdynamo_compiling() or disable_custom_kernels: - return original_forward(*forward_args, **forward_kwargs) - else: - return kernel_forward(*forward_args, **forward_kwargs) - - cls.forward = forward_with_compile_path - - return cls - - return decorator_with_compile_path - - except ImportError: # Stub to make decorators int transformers work when `kernels` # is not installed. diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 42f6d621183a..2678e4f125f6 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -123,6 +123,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states +@use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() From f4fea2a20a0c22cbdf04e70f260d16a7c3d311ca Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Wed, 2 Jul 2025 16:52:19 +0200 Subject: [PATCH 087/342] =?UTF-8?q?Update=20jinja=20template=20based=20on?= =?UTF-8?q?=20Harmony.=20=F0=9F=A4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../convert_openai_weights_to_hf.py | 65 ++++++++++++------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 3312ae58f71a..64883392458a 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -256,30 +256,47 @@ def __init__( def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): - # Chat template - chat_template = ( - "{% for message in messages %}" - "{% if loop.index0 == 0 %}" - "{{ bos_token }}" - "{% endif %}" - "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}" - "{% if message['content'] is string %}" - "{{ message['content'] }}" - "{% else %}" - "{% for content in message['content'] %}" - "{% if content['type'] == 'image' %}" - "{{ '<|image|>' }}" - "{% elif content['type'] == 'text' %}" - "{{ content['text'] }}" - "{% endif %}" - "{% endfor %}" - "{% endif %}" - "{{ '<|eot_id|>' }}" - "{% endfor %}" - "{% if add_generation_prompt %}" - "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" - "{% endif %}" - ) + # Updated Harmony chat template + chat_template = """{# Harmony chat template -------------------------------------------------- + This template mirrors the message rendering logic implemented in + `harmony/src/encoding.rs`. It can be consumed by Hugging Face + Transformers (``chat_template`` field) so that *text → tokens* + conversion of chat conversations happens fully on the Python side + without relying on the Rust renderer. + + Supported *message* keys (per ``chat::Message``): + - role (user│assistant│system│developer│tool) + - name (optional author name) + - recipient (optional recipient – omitted or "all" → broadcast) + - channel (optional meta channel) + - content_type (optional content-type qualifier) + - content (string – the actual message payload) + + The template renders each historical message *fully* (incl. the + trailing <|end|>/<|return|> sentinel) and – if ``add_generation_prompt`` + is True – appends a partial header for the **next** assistant turn + exactly like ``render_conversation_for_completion`` does on the Rust + side: ``<|start|>assistant``. +#} + +{%- macro harmony_header(m) -%} + <|start|>{% if m['role'] == 'tool' %}{{ m['name'] }}{% else %}{{ m['role'] }}{% if m.get('name') %}:{{ m['name'] }}{% endif %}{% endif %}{% if m.get('recipient') and m['recipient'] != 'all' %} to={{ m['recipient'] }}{% endif %}{% if m.get('channel') %}<|channel|>{{ m['channel'] }}{% endif %}{% if m.get('content_type') %} {{ m['content_type'] }}{% endif %}<|message|> +{%- endmacro -%} + +{# --------------------------------------------------------------------- + Render complete history +#} +{%- for message in messages -%} + {{- harmony_header(message) -}}{{ message['content'] }}{%- if message['role'] == 'assistant' -%}<|return|>{%- else -%}<|end|>{%- endif -%} +{%- endfor -%} + +{# --------------------------------------------------------------------- + Generation prompt for *next* assistant answer +#} +{%- if add_generation_prompt -%} +<|start|>assistant +{%- endif -%} +""" converter = OpenAIMoeConverter( vocab_file=tokenizer_path, From ba37581778cf10864390e0b7cbb793ffac67d2d7 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Wed, 2 Jul 2025 17:06:35 +0200 Subject: [PATCH 088/342] Add constrain + call token. --- .../convert_openai_weights_to_hf.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 64883392458a..0a39816bb846 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -237,13 +237,25 @@ def __init__( # TODO 1st donwload the vocabfile!!! tokenizer = tiktoken.get_encoding(vocab_file) self.additional_special_tokens = {} - # 199998 is not defined either - self.additional_special_tokens["<|reserved_199998|>"] = 199998 - self.additional_special_tokens = {"<|endoftext|>": 199999, "<|endofprompt|>": 200018} + # Build Harmony special tokens → IDs. + special_tokens_map = { + "<|reserved_199998|>": 199998, + "<|endoftext|>": 199999, # same as <|end|> + "<|constrain|>": 200003, + "<|call|>": 200012, + "<|endofprompt|>": 200018, + } + + # Add the remaining reserved slots while skipping already-used IDs. for k in range(199999, 200018): - self.additional_special_tokens[f"<|reserved_{k}|>"] = k - sorted_list = sorted(self.additional_special_tokens.items(), key=lambda x: x[1]) - self.additional_special_tokens = [k[0] for k in sorted_list] + if k in (200003, 200012): + continue + special_tokens_map.setdefault(f"<|reserved_{k}|>", k) + + # Keep only token strings (sorted by ID) for TikTokenConverter. + self.additional_special_tokens = [ + tok for tok, _ in sorted(special_tokens_map.items(), key=lambda x: x[1]) + ] tokenizer = self.converted() if chat_template is not None: kwargs["chat_template"] = chat_template From 22067ff2eb0d0d0e63e82dc27f46bb9d5a8de563 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 2 Jul 2025 17:10:05 +0000 Subject: [PATCH 089/342] refactor self.router to remove EP code from modeling --- .../integrations/tensor_parallel.py | 41 ++++++++++++++++++- .../openai_moe/configuration_openai_moe.py | 9 ++-- .../models/openai_moe/modeling_openai_moe.py | 34 ++++++++------- 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index e791eb1c0397..39628ec66068 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -127,7 +127,7 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Opti The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight"). """ generic_param_name = re.sub(r"\d+", "*", parameter_name) - if generic_param_name in tp_plan: + if generic_param_name in tp_plan: # TODO: i can't define hooks for parent modules, only leaf modules who have params return tp_plan[generic_param_name] elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan: return tp_plan[generic_param_name.rsplit(".", 1)[0]] @@ -751,6 +751,44 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param.contiguous() return param +class RouterParallel(TensorParallelLayer): + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.use_dtensor = False + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + raise ValueError("RouterParallel does not support DTensor input for now") + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() + num_local_experts = mod.num_experts // ep_size + router_scores, router_indices = outputs + router_scores = router_scores[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] + return router_scores, router_indices + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + param = param[...].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + return param + + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + # TODO: need an abstract Parallel class that is different from TensorParallelLayer + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, None, None), + partial(self._prepare_output_fn, None, None), + ) + + class ParallelInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given entry) @@ -768,6 +806,7 @@ class ParallelInterface(GeneralInterface): "sequence_parallel": SequenceParallel(), "replicate": ReplicateParallel(), "grouped_gemm": GroupedGemmParallel(), + "ep_router": RouterParallel(), } if is_torch_greater_or_equal("2.5") and _torch_distributed_available else {} diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 668cb897572f..558605894c1e 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -41,10 +41,11 @@ class OpenAIMoeConfig(PretrainedConfig): # "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs # "layers.*.mlp.experts": "gather", - 'model.layers.*.mlp.experts.gate_up_proj': "grouped_gemm", - 'model.layers.*.mlp.experts.gate_up_proj_bias': "grouped_gemm", - 'model.layers.*.mlp.experts.down_proj': "grouped_gemm", - 'model.layers.*.mlp.experts.down_proj_bias': "grouped_gemm", + 'layers.*.mlp.experts.gate_up_proj': "grouped_gemm", + 'layers.*.mlp.experts.gate_up_proj_bias': "grouped_gemm", + 'layers.*.mlp.experts.down_proj': "grouped_gemm", + 'layers.*.mlp.experts.down_proj_bias': "grouped_gemm", + "layers.*.mlp.router": "ep_router", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 951d8cb7ed8e..e5ef9f1296b2 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -128,40 +128,46 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = next_states.view(-1, self.hidden_size) return next_states, None # bcz GatherParallel allreduces first element +class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module + def __init__(self, config): + super().__init__(config.hidden_size, config.num_local_experts, bias=True) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + # TODO: is it better to define self.weight and self.bias as nn.Parameter instead to keep the same namings: mlp.router.weight instead of mlp.router.router.weight? + + def forward(self, hidden_states): + router_logits = super().forward(hidden_states) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + return router_scores, router_indices + class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.top_k = config.num_experts_per_tok - # self.top_k = 1 self.hidden_dim = config.hidden_size self.enable_expert_parallel = config.enable_expert_parallel self.ep_size = dist.get_world_size() if (self.enable_expert_parallel and dist.is_initialized()) else 1 self.ep_rank = dist.get_rank() if (self.enable_expert_parallel and dist.is_initialized()) else 0 self.num_local_experts = config.num_local_experts // self.ep_size # TODO: bad naming self.experts = OpenAIMoeExperts(config) - self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) + self.router = TopKRouter(config) def forward(self, hidden_states): # we don't slice weight as its not compile compatible batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) - if self.enable_expert_parallel: - local_router_scores = router_scores[self.ep_rank * self.num_local_experts:(self.ep_rank + 1) * self.num_local_experts] - else: - local_router_scores = router_scores - routed_out, _ = self.experts(hidden_states, router_indices, local_router_scores) #TODO: router_indices isn't used inside this func + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + + routed_out, _ = self.experts(hidden_states, router_indices, router_scores) #TODO: router_indices isn't used inside this func if self.training: output_states = routed_out.view(batch_size, -1, self.hidden_dim) else: - routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * local_router_scores[..., None] # we're throwing away computed routed_out for rest of experts + routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] # we're throwing away computed routed_out for rest of experts output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) if self.enable_expert_parallel: torch.distributed.all_reduce(output_states, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook - return output_states, local_router_scores + return output_states, router_scores class OpenAIMoeRotaryEmbedding(nn.Module): From 80fdcbc7e1c08e673b63eb6babaf2841633730d4 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 2 Jul 2025 17:33:11 +0000 Subject: [PATCH 090/342] refactor OpenAIMoeExperts to remove EP code --- .../integrations/tensor_parallel.py | 20 +++++++- .../openai_moe/configuration_openai_moe.py | 1 + .../models/openai_moe/modeling_openai_moe.py | 47 ++++++++----------- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 39628ec66068..b603b3d39d16 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -751,6 +751,24 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param.contiguous() return param + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + if mod.config.enable_expert_parallel: + torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook + return outputs + + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + # TODO: need an abstract Parallel class that is different from TensorParallelLayer + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, None, None), + partial(self._prepare_output_fn, None, None), + ) + + class RouterParallel(TensorParallelLayer): def __init__(self, *args, **kwargs): self.args = args @@ -761,7 +779,7 @@ def __init__(self, *args, **kwargs): def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): input_tensor = inputs[0] if isinstance(input_tensor, DTensor): - raise ValueError("RouterParallel does not support DTensor input for now") + raise NotImplementedError("RouterParallel does not support DTensor input for now") return input_tensor @staticmethod diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 558605894c1e..8cc76763d246 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -45,6 +45,7 @@ class OpenAIMoeConfig(PretrainedConfig): 'layers.*.mlp.experts.gate_up_proj_bias': "grouped_gemm", 'layers.*.mlp.experts.down_proj': "grouped_gemm", 'layers.*.mlp.experts.down_proj_bias': "grouped_gemm", + 'layers.*.mlp.experts': "grouped_gemm", "layers.*.mlp.router": "ep_router", } base_model_pp_plan = { diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e5ef9f1296b2..7fa93b36f0cd 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -65,19 +65,18 @@ def extra_repr(self): class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.enable_expert_parallel = config.enable_expert_parallel - self.ep_size = dist.get_world_size() if (self.enable_expert_parallel and dist.is_initialized()) else 1 - self.num_experts = config.num_local_experts // self.ep_size + self.config = config # needed for expert parallelism self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((config.num_local_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size)) self.alpha = 1.702 + self.enable_expert_parallel = config.enable_expert_parallel - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=None, routing_weights=None) -> torch.Tensor: """ When training is is more efficient to just loop over the experts and compute the output for each expert as otherwise the memory would explode. @@ -88,13 +87,15 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) routing_weights (torch.Tensor): (batch_size * token_num, top_k) + batch_size (int) # TODO: it's ugly to pass batch_size here :/ Returns: torch.Tensor """ + num_experts = routing_weights.shape[0] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute( + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( 2, 1, 0 ) expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() @@ -115,9 +116,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig ) # (num_tokens, hidden_dim) weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + next_states = next_states.view(batch_size, -1, self.hidden_size) else: - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) @@ -125,8 +127,9 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if not self.enable_expert_parallel: torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(-1, self.hidden_size) - return next_states, None # bcz GatherParallel allreduces first element + next_states = next_states.view(num_experts, -1, self.hidden_size) * routing_weights[..., None] # we're throwing away computed routed_out for rest of experts + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size).sum(dim=0) + return next_states class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module def __init__(self, config): @@ -146,28 +149,16 @@ class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_dim = config.hidden_size - self.enable_expert_parallel = config.enable_expert_parallel - self.ep_size = dist.get_world_size() if (self.enable_expert_parallel and dist.is_initialized()) else 1 - self.ep_rank = dist.get_rank() if (self.enable_expert_parallel and dist.is_initialized()) else 0 - self.num_local_experts = config.num_local_experts // self.ep_size # TODO: bad naming - self.experts = OpenAIMoeExperts(config) self.router = TopKRouter(config) + self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): # we don't slice weight as its not compile compatible batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - - routed_out, _ = self.experts(hidden_states, router_indices, router_scores) #TODO: router_indices isn't used inside this func - if self.training: - output_states = routed_out.view(batch_size, -1, self.hidden_dim) - else: - routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] # we're throwing away computed routed_out for rest of experts - output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) - if self.enable_expert_parallel: - torch.distributed.all_reduce(output_states, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook - return output_states, router_scores + routed_out = self.experts(hidden_states, batch_size=batch_size, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + return routed_out, router_scores class OpenAIMoeRotaryEmbedding(nn.Module): From 1e9ffa2a241ba2abb4d689d5d1b3a66c575c9fe9 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 09:44:22 +0000 Subject: [PATCH 091/342] revert GatherParallel --- src/transformers/integrations/tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b603b3d39d16..22259d7efcbd 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -391,7 +391,7 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # this op cannot be async, otherwise it completely breaks the outputs of models - # torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: need to create a new class for backward compatibility + torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something return outputs From e33c9057ef346fd6cbe1843b132345d5e32378ee Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 09:50:54 +0000 Subject: [PATCH 092/342] cleaned EP code from modeling --- hf_generate.py | 2 +- src/transformers/configuration_utils.py | 5 ----- .../integrations/tensor_parallel.py | 8 +++---- src/transformers/modeling_utils.py | 22 ++++++++++++++++++- .../openai_moe/configuration_openai_moe.py | 15 ++++--------- .../models/openai_moe/modeling_openai_moe.py | 5 ++--- 6 files changed, 32 insertions(+), 25 deletions(-) diff --git a/hf_generate.py b/hf_generate.py index 2232908cd2cd..d22ce8a0ec73 100644 --- a/hf_generate.py +++ b/hf_generate.py @@ -64,8 +64,8 @@ def main(): enable_expert_parallel=os.environ.get("ENABLE_EXPERT_PARALLEL", "0") == "1", tp_size=tp_size, torch_dtype=torch.bfloat16, - # torch_dtype=torch.float32, attn_implementation="eager", + # key_mapping={"mlp.router": "mlp.router.router"}, ) logger.info(f"Model loaded onto device mesh: {tp_mesh}") device = torch.device(f"cuda:{local_rank}") diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 91f5f878de25..205a7dde8f28 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -302,9 +302,6 @@ def __init__(self, **kwargs): self._attn_implementation_internal = kwargs.pop("attn_implementation", None) self._attn_implementation_autoset = False - # Expert parallelism - self.enable_expert_parallel = kwargs.pop("enable_expert_parallel", False) - # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -1020,8 +1017,6 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None: del d["_attn_implementation_internal"] if "_attn_implementation_autoset" in d: del d["_attn_implementation_autoset"] - if "enable_expert_parallel" in d: - del d["enable_expert_parallel"] # Do not serialize `base_model_tp_plan` for now if "base_model_tp_plan" in d: del d["base_model_tp_plan"] diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 22259d7efcbd..bf53490bc3df 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -736,7 +736,8 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) -class GroupedGemmParallel(TensorParallelLayer): +class GroupedGemmParallel(TensorParallelLayer): # self.experts + # Applies EP to MoE experts def __init__(self): super().__init__() self.use_dtensor = False @@ -751,11 +752,9 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param.contiguous() return param - @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - if mod.config.enable_expert_parallel: - torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook + torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM) return outputs @@ -770,6 +769,7 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: class RouterParallel(TensorParallelLayer): + # applies EP to router def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7cd641971b7f..bc1758509f0d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -777,7 +777,7 @@ def _load_state_dict_into_meta_model( file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) for param_name, empty_param in state_dict.items(): - if param_name not in expected_keys: + if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling continue # we need to use serialized_param_name as file pointer is untouched @@ -1998,6 +1998,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # - `_pp_plan["layers"][PipelineParallel.outputs]` _pp_plan = None + # Whether expert parallelism is enabled for the model. In that case we override + # `base_model_tp_plan` with expert parallel plan + _enable_expert_parallel = False + # This flag signal that the model can be used as an efficient backend in TGI and vLLM # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan @@ -4262,6 +4266,7 @@ def from_pretrained( gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) tp_size = kwargs.pop("tp_size", None) + enable_expert_parallel = kwargs.pop("enable_expert_parallel", False) device_mesh = kwargs.pop("device_mesh", None) trust_remote_code = kwargs.pop("trust_remote_code", None) @@ -4624,6 +4629,21 @@ def from_pretrained( device_map=device_map, ) + if enable_expert_parallel: + # TODO: add proper support for ep_plan independently of tp_plan + if config.base_model_tp_plan is None: + raise ValueError("base_model_tp_plan is required when enable_expert_parallel is True") + # We apply ep on MoE layers + config.base_model_tp_plan.update({ + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", + # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them + 'layers.*.mlp.experts': "grouped_gemm", + "layers.*.mlp.router": "ep_router", + }) + with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 8cc76763d246..e233878b1e4f 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -35,18 +35,11 @@ class OpenAIMoeConfig(PretrainedConfig): # "layers.*.self_attn.o_proj": "rowwise", # "layers.*.self_attn.sinks": "local_rowwise", - # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - # "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - # "layers.*.mlp.experts.down_proj": "local_colwise", - # "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs + "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj": "local_colwise", + "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs # "layers.*.mlp.experts": "gather", - - 'layers.*.mlp.experts.gate_up_proj': "grouped_gemm", - 'layers.*.mlp.experts.gate_up_proj_bias': "grouped_gemm", - 'layers.*.mlp.experts.down_proj': "grouped_gemm", - 'layers.*.mlp.experts.down_proj_bias': "grouped_gemm", - 'layers.*.mlp.experts': "grouped_gemm", - "layers.*.mlp.router": "ep_router", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 7fa93b36f0cd..c3d1aa370785 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -74,7 +74,6 @@ def __init__(self, config): self.down_proj = nn.Parameter(torch.empty((config.num_local_experts, self.expert_dim, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size)) self.alpha = 1.702 - self.enable_expert_parallel = config.enable_expert_parallel def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=None, routing_weights=None) -> torch.Tensor: """ @@ -124,8 +123,8 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=N gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) - if not self.enable_expert_parallel: - torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook + # if not self.enable_expert_parallel: + # torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, -1, self.hidden_size) * routing_weights[..., None] # we're throwing away computed routed_out for rest of experts next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size).sum(dim=0) From b6f132afec99212ec2c8c6fb226d6a0f133bbfe8 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 09:52:34 +0000 Subject: [PATCH 093/342] assert -> raise --- src/transformers/integrations/tensor_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index bf53490bc3df..b9d37e074716 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -745,7 +745,8 @@ def __init__(self): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): ep_rank = rank global_num_experts = empty_param.shape[0] - assert global_num_experts % device_mesh.size() == 0, f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0" + if global_num_experts % device_mesh.size() != 0: + raise ValueError(f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0") local_num_experts = global_num_experts // device_mesh.size() param = param[ep_rank*local_num_experts:(ep_rank+1)*local_num_experts].to(param_casting_dtype) if to_contiguous: From 910671c3bb23cd4c0c2110b96e545c54bc94c6ad Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 10:23:58 +0000 Subject: [PATCH 094/342] fix and cleanup TP=2 --- hf_generate.py | 2 +- .../integrations/tensor_parallel.py | 5 ++- .../openai_moe/configuration_openai_moe.py | 4 +- .../models/openai_moe/modeling_openai_moe.py | 39 ++++++++++++------- 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/hf_generate.py b/hf_generate.py index d22ce8a0ec73..733ce6b48755 100644 --- a/hf_generate.py +++ b/hf_generate.py @@ -1,5 +1,5 @@ """ -TP_SIZE=8 torchrun --nproc_per_node=8 hf_generate.py +ENABLE_EXPERT_PARALLEL=1 TP_SIZE=8 torchrun --nproc_per_node=8 hf_generate.py """ from transformers import AutoTokenizer, OpenAIMoeForCausalLM diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b9d37e074716..22011f9fec04 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -391,6 +391,8 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # this op cannot be async, otherwise it completely breaks the outputs of models + if isinstance(outputs, torch.Tensor): + raise ValueError("GatherParallel should not be used with a single tensor, it should be used with a tuple of tensors") torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something return outputs @@ -418,7 +420,8 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param[...].to(param_casting_dtype) if to_contiguous: param = param.contiguous() - # param = param / device_mesh.size() # TODO should be optionable + param = param / device_mesh.size() # TODO should be optionable + # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel) return param def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index e233878b1e4f..a58edd581b37 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -38,8 +38,8 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", # TODO: maybe add smthg that says bias exists only once for all TPs - # "layers.*.mlp.experts": "gather", + "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs + "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index c3d1aa370785..25d4acc4ca90 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Tuple, Union -import torch.distributed as dist import torch from torch import nn @@ -75,7 +74,7 @@ def __init__(self, config): self.down_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size)) self.alpha = 1.702 - def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=None, routing_weights=None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: """ When training is is more efficient to just loop over the experts and compute the output for each expert as otherwise the memory would explode. @@ -83,13 +82,14 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=N For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. Args: - hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) routing_weights (torch.Tensor): (batch_size * token_num, top_k) - batch_size (int) # TODO: it's ugly to pass batch_size here :/ Returns: torch.Tensor """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[0] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) @@ -123,41 +123,52 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int, router_indices=N gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) - # if not self.enable_expert_parallel: - # torch.distributed.all_reduce(next_states, op=torch.distributed.ReduceOp.SUM) # TODO: if we can attach hook to `down_proj` we can move this to hook next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, -1, self.hidden_size) * routing_weights[..., None] # we're throwing away computed routed_out for rest of experts - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size).sum(dim=0) - return next_states + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) + return next_states, None class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module def __init__(self, config): super().__init__(config.hidden_size, config.num_local_experts, bias=True) self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size # TODO: is it better to define self.weight and self.bias as nn.Parameter instead to keep the same namings: mlp.router.weight instead of mlp.router.router.weight? def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = super().forward(hidden_states) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) return router_scores, router_indices +class TokenDispatcher(nn.Module): + # TODO: this only exists to add EP hook + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + def forward(self, routed_out, routing_weights): + # routed_out is (num_experts, batch_size, seq_len, hidden_size) + routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts + routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) + return routed_out + class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.hidden_dim = config.hidden_size self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) + self.token_dispatcher = TokenDispatcher(config) # TODO: i need this class because TP needs hook right after down_proj_bias, and EP needs hook right after routing_weights def forward(self, hidden_states): # we don't slice weight as its not compile compatible - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out = self.experts(hidden_states, batch_size=batch_size, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func - return routed_out, router_scores + routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + hidden_states = self.token_dispatcher(routed_out, router_scores) + return hidden_states, router_scores class OpenAIMoeRotaryEmbedding(nn.Module): From e575e85abfc5e1ce4e7cef4eb75301bf9523cf63 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 10:57:32 +0000 Subject: [PATCH 095/342] can't get both EP and TP cases hooked correctly, im gonna cry T.T --- src/transformers/integrations/tensor_parallel.py | 6 ++++-- src/transformers/modeling_utils.py | 8 +++++++- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 22011f9fec04..96d80763942e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -758,7 +758,9 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM) + if isinstance(outputs, torch.Tensor): + raise ValueError("GroupedGemmParallel should not be used with a single tensor, it should be used with a tuple of tensors") + torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM) return outputs @@ -939,7 +941,7 @@ def shard_and_distribute_module( """ param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name tp_plan = model._tp_plan - module_to_tp = model.get_submodule(param_name) + module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules? rank = int(rank) current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bc1758509f0d..efe8f4cb77ce 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4640,9 +4640,15 @@ def from_pretrained( "layers.*.mlp.experts.down_proj": "grouped_gemm", "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them - 'layers.*.mlp.experts': "grouped_gemm", + 'layers.*.mlp.experts': None, + 'layers.*.mlp.token_dispatcher': "gather", + 'layers.*.mlp': "gather", "layers.*.mlp.router": "ep_router", }) + # Remove None values from the tp plan + config.base_model_tp_plan = {k: v for k, v in config.base_model_tp_plan.items() if v is not None} + + print(f"using base_model_tp_plan: {config.base_model_tp_plan}") with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 25d4acc4ca90..2d2afb2b9027 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -154,6 +154,7 @@ def forward(self, routed_out, routing_weights): # routed_out is (num_experts, batch_size, seq_len, hidden_size) routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) + torch.distributed.all_reduce(routed_out, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook return routed_out class OpenAIMoeMLP(nn.Module): From c0118ade6837116319f6688db438ccae534966c8 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 13:33:36 +0200 Subject: [PATCH 096/342] update missing mapping. --- .../openai_moe/convert_openai_weights_to_hf.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 0a39816bb846..56a7ada7aa67 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -237,18 +237,24 @@ def __init__( # TODO 1st donwload the vocabfile!!! tokenizer = tiktoken.get_encoding(vocab_file) self.additional_special_tokens = {} - # Build Harmony special tokens → IDs. + # Complete list of Harmony special tokens as per o200k_harmony spec special_tokens_map = { - "<|reserved_199998|>": 199998, - "<|endoftext|>": 199999, # same as <|end|> + "<|startoftext|>": 199998, + "<|endoftext|>": 199999, + "<|return|>": 200002, "<|constrain|>": 200003, + "<|channel|>": 200005, + "<|start|>": 200006, + "<|end|>": 200007, + "<|message|>": 200008, "<|call|>": 200012, "<|endofprompt|>": 200018, } - # Add the remaining reserved slots while skipping already-used IDs. + # Add the remaining reserved slots while skipping IDs already present above. + used_ids = set(special_tokens_map.values()) for k in range(199999, 200018): - if k in (200003, 200012): + if k in used_ids: continue special_tokens_map.setdefault(f"<|reserved_{k}|>", k) From 3f5178c85db387ed66724d0af4377c24bd1f31bb Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 11:42:12 +0000 Subject: [PATCH 097/342] we can now attach hooks to modules separately --- .../integrations/tensor_parallel.py | 5 +++-- src/transformers/modeling_utils.py | 18 +++++++++++++++++- .../models/openai_moe/modeling_openai_moe.py | 1 - 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 96d80763942e..998e355b81cb 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -392,7 +392,8 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # this op cannot be async, otherwise it completely breaks the outputs of models if isinstance(outputs, torch.Tensor): - raise ValueError("GatherParallel should not be used with a single tensor, it should be used with a tuple of tensors") + torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM, async_op=False) + # TODO: we assume we want to allreduce first element of tuple torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something return outputs @@ -929,7 +930,7 @@ def __init__(self): def shard_and_distribute_module( model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh -): +): # TODO: rename to shard_and_distribute_param r""" Main uses cases: - column / rowise parallelism, you just shard all the weights of the layer (weight and bias) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index efe8f4cb77ce..a9dd17d9eac3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2095,6 +2095,23 @@ def post_init(self): f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) + # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit + # device_mesh = self._device_mesh # TODO: can we attach device_mesh to model + device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(torch.distributed.distributed_c10d._get_default_group(), device_type="cuda") # TODO: + for name, module in self.named_modules(): + if not getattr(module, "_is_hooked", False): # this adds gather hook to layers.*.mlp.experts and skips the rest + from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module + add_tensor_parallel_hooks_to_module( + model=self, + module=module, + tp_plan=self._tp_plan, + layer_name="", # TODO: make this optional? + current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), + device_mesh=device_mesh, + parameter_name=None + ) + module._is_hooked = True + def dequantize(self): """ Potentially dequantize the model in case it has been quantized by a quantization method that support @@ -4642,7 +4659,6 @@ def from_pretrained( # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them 'layers.*.mlp.experts': None, 'layers.*.mlp.token_dispatcher': "gather", - 'layers.*.mlp': "gather", "layers.*.mlp.router": "ep_router", }) # Remove None values from the tp plan diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 2d2afb2b9027..25d4acc4ca90 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -154,7 +154,6 @@ def forward(self, routed_out, routing_weights): # routed_out is (num_experts, batch_size, seq_len, hidden_size) routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) - torch.distributed.all_reduce(routed_out, op=torch.distributed.ReduceOp.SUM) # TODO: move to hook return routed_out class OpenAIMoeMLP(nn.Module): From 7ce144d50efecf52ab70a9a0804ea1cb36de205a Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 12:24:59 +0000 Subject: [PATCH 098/342] =?UTF-8?q?amend.=20Now=20EP2,=20EP1,=20TP2,=20TP1?= =?UTF-8?q?=20match=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/integrations/tensor_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 998e355b81cb..3c70da3b5b22 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -393,8 +393,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # this op cannot be async, otherwise it completely breaks the outputs of models if isinstance(outputs, torch.Tensor): torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM, async_op=False) - # TODO: we assume we want to allreduce first element of tuple - torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something + else: + # TODO: we assume we want to allreduce first element of tuple + torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something return outputs From c39dc4a0cca55b1c3a93038da83200a8de400673 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 14:37:36 +0200 Subject: [PATCH 099/342] add tests for tokenizer. --- .../openai_moe/test_modeling_openai_moe.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 8310505ddd6a..ad85ab73e331 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -31,6 +31,8 @@ require_read_token, require_torch, require_torch_accelerator, + require_tokenizers, + require_tiktoken, slow, torch_device, ) @@ -399,3 +401,102 @@ def test_compile_static_cache(self): ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + +# ============================================================= +# Tokenizer ↔️ tiktoken equivalence checks +# ============================================================= + + +@require_tokenizers +@require_tiktoken +class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): + """Ensure the HF tokenizer extracted with `OpenAIMoeConverter` remains byte-level identical to + the reference `o200k_harmony` tiktoken encoding. + + Set the environment variable ``OPENAI_MOE_TOKENIZER_PATH`` to override the tokenizer location + (repo ID or local path). When the variable is unset the default public checkpoint is used. + """ + + @classmethod + def setUpClass(cls): + import os + import tiktoken + + # Load the HF tokenizer (fast implementation) + tokenizer_id = os.getenv("OPENAI_MOE_TOKENIZER_PATH", "meta-openai/Openai-2-7b-hf") + cls.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) + + # Build the (pre-release) o200k_harmony encoding for tiktoken + o200k_base = tiktoken.get_encoding("o200k_base") + cls.tkt_encoding = tiktoken.Encoding( + name="o200k_harmony", + pat_str=o200k_base._pat_str, + mergeable_ranks=o200k_base._mergeable_ranks, + special_tokens={ + **o200k_base._special_tokens, + "<|startoftext|>": 199998, + "<|endoftext|>": 199999, + "<|return|>": 200002, + "<|constrain|>": 200003, + "<|channel|>": 200005, + "<|start|>": 200006, + "<|end|>": 200007, + "<|message|>": 200008, + "<|call|>": 200012, + }, + ) + + # -------------------------------------------- + # Helper util + # -------------------------------------------- + def _assert_equivalent(self, string: str): + # Encode + ids_hf = self.tokenizer.encode(string) + ids_tk = self.tkt_encoding.encode(string, allowed_special="all") + self.assertEqual(ids_hf, ids_tk, msg=f"HF vs tiktoken mismatch on: {string!r}") + + # Decode round-trip (with special tokens preserved) + decoded_hf = self.tokenizer.decode(ids_hf, skip_special_tokens=False) + decoded_tk = self.tkt_encoding.decode(ids_tk) + self.assertEqual(decoded_hf, decoded_tk, msg=f"Decode diff on: {string!r}") + + # -------------------------------------------- + # Quick unit test on a handful of hand-picked strings + # -------------------------------------------- + def test_equivalence_on_simple_strings(self): + samples = [ + "", + "Hello world!", + " ", + "I ❤️ Transformers 🤗", + "def foo(x): return x * x", + "<|start|><|message|>user\nHello<|end|>", + ] + + for s in samples: + with self.subTest(sample=s): + self._assert_equivalent(s) + + # -------------------------------------------- + # Heavier integration test gated behind env flag + # -------------------------------------------- + @unittest.skipIf( + os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0", + "Set RUN_TOKENIZER_INTEGRATION=1 to enable slow tokenizer equivalence tests", + ) + @slow + def test_equivalence_on_public_datasets(self): + import tqdm + from datasets import load_dataset + + # 1) Code-to-text dataset + ds = load_dataset("google/code_x_glue_ct_code_to_text", "go") + for item in tqdm.tqdm(ds["validation"]): + self._assert_equivalent(item["code"]) + + # 2) XNLI premises across all languages + ds = load_dataset("facebook/xnli", "all_languages") + for item in tqdm.tqdm(ds["train"]): + for premise in item["premise"].values(): + self._assert_equivalent(premise) From e28234a27eff74ac0e77037aac23d64eee189bbc Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 14:51:03 +0200 Subject: [PATCH 100/342] up. --- tests/models/openai_moe/test_modeling_openai_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index ad85ab73e331..dea18f8d0b3d 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -424,8 +424,7 @@ def setUpClass(cls): import tiktoken # Load the HF tokenizer (fast implementation) - tokenizer_id = os.getenv("OPENAI_MOE_TOKENIZER_PATH", "meta-openai/Openai-2-7b-hf") - cls.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) + cls.tokenizer = AutoTokenizer.from_pretrained("/fsx/vb/converted_model") # Build the (pre-release) o200k_harmony encoding for tiktoken o200k_base = tiktoken.get_encoding("o200k_base") From fc1a636e50680cdae5899bfaffc9008ab0c22ab6 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 14:53:24 +0200 Subject: [PATCH 101/342] up. --- tests/models/openai_moe/test_modeling_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index dea18f8d0b3d..a698903d0819 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch openai model.""" +import os import unittest import torch From 95ffa0ef2ce1e299da491463fe02336558918f1e Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 14:57:36 +0200 Subject: [PATCH 102/342] up. --- tests/models/openai_moe/test_modeling_openai_moe.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index a698903d0819..e799a1e2770a 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -414,9 +414,6 @@ def test_compile_static_cache(self): class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): """Ensure the HF tokenizer extracted with `OpenAIMoeConverter` remains byte-level identical to the reference `o200k_harmony` tiktoken encoding. - - Set the environment variable ``OPENAI_MOE_TOKENIZER_PATH`` to override the tokenizer location - (repo ID or local path). When the variable is unset the default public checkpoint is used. """ @classmethod @@ -481,10 +478,6 @@ def test_equivalence_on_simple_strings(self): # -------------------------------------------- # Heavier integration test gated behind env flag # -------------------------------------------- - @unittest.skipIf( - os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0", - "Set RUN_TOKENIZER_INTEGRATION=1 to enable slow tokenizer equivalence tests", - ) @slow def test_equivalence_on_public_datasets(self): import tqdm From c7d2b60983763b6f62025a91428bfb2eb2e58741 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 13:04:43 +0000 Subject: [PATCH 103/342] final cleaning --- .../integrations/tensor_parallel.py | 29 +++++---------- .../models/openai_moe/modeling_openai_moe.py | 35 +++++++++++++------ 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 3c70da3b5b22..8c09489b9609 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -741,8 +741,10 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) -class GroupedGemmParallel(TensorParallelLayer): # self.experts - # Applies EP to MoE experts +class GroupedGemmParallel(TensorParallelLayer): + """ + Applies Expert Parallelism to MoE experts by loading the correct experts on each device. + """ def __init__(self): super().__init__() self.use_dtensor = False @@ -758,26 +760,10 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param.contiguous() return param - @staticmethod - def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - if isinstance(outputs, torch.Tensor): - raise ValueError("GroupedGemmParallel should not be used with a single tensor, it should be used with a tuple of tensors") - torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM) - return outputs - - - def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: - # TODO: need an abstract Parallel class that is different from TensorParallelLayer - distribute_module( - module, - device_mesh, - partial(self._prepare_input_fn, None, None), - partial(self._prepare_output_fn, None, None), - ) - - class RouterParallel(TensorParallelLayer): - # applies EP to router + """ + Applies Expert Parallelism to MoE router + """ def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs @@ -799,6 +785,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # TODO: i'd like for this to be the default param = param[...].to(param_casting_dtype) if to_contiguous: param = param.contiguous() diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 25d4acc4ca90..473dc0488b74 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -64,7 +64,6 @@ def extra_repr(self): class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.config = config # needed for expert parallelism self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size @@ -91,7 +90,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[0] - if self.training: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[0] + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( @@ -112,17 +114,14 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gated_output = (up + 1) * glu # (num_tokens, interm_dim) out = ( gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + next_states = next_states.view(batch_size, -1, self.hidden_size) ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) - next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) next_states = next_states.view(batch_size, -1, self.hidden_size) else: hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) return next_states, None @@ -133,7 +132,6 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - # TODO: is it better to define self.weight and self.bias as nn.Parameter instead to keep the same namings: mlp.router.weight instead of mlp.router.router.weight? def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -144,7 +142,7 @@ def forward(self, hidden_states): return router_scores, router_indices class TokenDispatcher(nn.Module): - # TODO: this only exists to add EP hook + # this module is important to add EP hook def __init__(self, config): super().__init__() self.config = config @@ -156,6 +154,23 @@ def forward(self, routed_out, routing_weights): routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) return routed_out +class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module + def __init__(self, config): + super().__init__(config.hidden_size, config.num_local_experts, bias=True) + self.router = TopKRouter(config) + # TODO: is it better to define self.weight and self.bias as nn.Parameter instead to keep the same namings: mlp.router.weight instead of mlp.router.router.weight? + self.token_dispatcher = TokenDispatcher(config) + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = super().forward(hidden_states) # (seq_len, num_experts) + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + hidden_states = self.token_dispatcher(routed_out, router_scores) + return hidden_states, router_scores + # routed_out is (num_experts, batch_size, seq_len, hidden_size) + routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts + routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) + return routed_out + class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() From b0ef0ec8fe08fa1d2ce2b9b336575f3ec1245140 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 15:05:18 +0200 Subject: [PATCH 104/342] up. --- tests/models/openai_moe/test_modeling_openai_moe.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index e799a1e2770a..a6f2ea8e05c0 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -403,12 +403,6 @@ def test_compile_static_cache(self): static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - -# ============================================================= -# Tokenizer ↔️ tiktoken equivalence checks -# ============================================================= - - @require_tokenizers @require_tiktoken class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): @@ -416,17 +410,16 @@ class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): the reference `o200k_harmony` tiktoken encoding. """ - @classmethod - def setUpClass(cls): + def setUp(self): import os import tiktoken # Load the HF tokenizer (fast implementation) - cls.tokenizer = AutoTokenizer.from_pretrained("/fsx/vb/converted_model") + self.tokenizer = AutoTokenizer.from_pretrained("/fsx/vb/converted_model") # Build the (pre-release) o200k_harmony encoding for tiktoken o200k_base = tiktoken.get_encoding("o200k_base") - cls.tkt_encoding = tiktoken.Encoding( + self.tkt_encoding = tiktoken.Encoding( name="o200k_harmony", pat_str=o200k_base._pat_str, mergeable_ranks=o200k_base._mergeable_ranks, From 271caa41fbe12daf344510ccfb75d23a8b44a159 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 13:08:42 +0000 Subject: [PATCH 105/342] amend --- .../models/openai_moe/modeling_openai_moe.py | 34 +++++-------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 473dc0488b74..2b0c1ea2d701 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -90,10 +90,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[0] - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[0] - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( + if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( @@ -114,14 +111,17 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gated_output = (up + 1) * glu # (num_tokens, interm_dim) out = ( gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - next_states = next_states.view(batch_size, -1, self.hidden_size) ) # (num_tokens, hidden_dim) - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) next_states = next_states.view(batch_size, -1, self.hidden_size) else: hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) return next_states, None @@ -154,30 +154,12 @@ def forward(self, routed_out, routing_weights): routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) return routed_out -class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module - def __init__(self, config): - super().__init__(config.hidden_size, config.num_local_experts, bias=True) - self.router = TopKRouter(config) - # TODO: is it better to define self.weight and self.bias as nn.Parameter instead to keep the same namings: mlp.router.weight instead of mlp.router.router.weight? - self.token_dispatcher = TokenDispatcher(config) - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = super().forward(hidden_states) # (seq_len, num_experts) - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func - hidden_states = self.token_dispatcher(routed_out, router_scores) - return hidden_states, router_scores - # routed_out is (num_experts, batch_size, seq_len, hidden_size) - routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts - routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) - return routed_out - class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) - self.token_dispatcher = TokenDispatcher(config) # TODO: i need this class because TP needs hook right after down_proj_bias, and EP needs hook right after routing_weights - + self.token_dispatcher = TokenDispatcher(config) def forward(self, hidden_states): # we don't slice weight as its not compile compatible router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) From bea1c418bae77beffd22598ecb752f9d1784cd0b Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 13:09:25 +0000 Subject: [PATCH 106/342] . --- src/transformers/integrations/tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 8c09489b9609..3a9cbac56d11 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -127,7 +127,7 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Opti The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight"). """ generic_param_name = re.sub(r"\d+", "*", parameter_name) - if generic_param_name in tp_plan: # TODO: i can't define hooks for parent modules, only leaf modules who have params + if generic_param_name in tp_plan: return tp_plan[generic_param_name] elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan: return tp_plan[generic_param_name.rsplit(".", 1)[0]] From 37a3caf086e415ed7c59b3e0f553bb9eea865879 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Thu, 3 Jul 2025 15:52:23 +0200 Subject: [PATCH 107/342] up. --- .../openai_moe/test_modeling_openai_moe.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index a6f2ea8e05c0..c81ed898116f 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -403,8 +403,6 @@ def test_compile_static_cache(self): static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) -@require_tokenizers -@require_tiktoken class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): """Ensure the HF tokenizer extracted with `OpenAIMoeConverter` remains byte-level identical to the reference `o200k_harmony` tiktoken encoding. @@ -437,9 +435,6 @@ def setUp(self): }, ) - # -------------------------------------------- - # Helper util - # -------------------------------------------- def _assert_equivalent(self, string: str): # Encode ids_hf = self.tokenizer.encode(string) @@ -451,26 +446,6 @@ def _assert_equivalent(self, string: str): decoded_tk = self.tkt_encoding.decode(ids_tk) self.assertEqual(decoded_hf, decoded_tk, msg=f"Decode diff on: {string!r}") - # -------------------------------------------- - # Quick unit test on a handful of hand-picked strings - # -------------------------------------------- - def test_equivalence_on_simple_strings(self): - samples = [ - "", - "Hello world!", - " ", - "I ❤️ Transformers 🤗", - "def foo(x): return x * x", - "<|start|><|message|>user\nHello<|end|>", - ] - - for s in samples: - with self.subTest(sample=s): - self._assert_equivalent(s) - - # -------------------------------------------- - # Heavier integration test gated behind env flag - # -------------------------------------------- @slow def test_equivalence_on_public_datasets(self): import tqdm From 300e3c9b425a71c92ac5f0772a06e86d1f11e127 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 14:35:09 +0000 Subject: [PATCH 108/342] =?UTF-8?q?create=20DistributedConfig=20=E2=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hf_generate.py | 6 +- src/transformers/distributed/__init__.py | 35 ++++++ .../distributed/configuration_utils.py | 118 ++++++++++++++++++ src/transformers/modeling_utils.py | 8 +- 4 files changed, 161 insertions(+), 6 deletions(-) create mode 100644 src/transformers/distributed/__init__.py create mode 100644 src/transformers/distributed/configuration_utils.py diff --git a/hf_generate.py b/hf_generate.py index 733ce6b48755..6750d65085e2 100644 --- a/hf_generate.py +++ b/hf_generate.py @@ -10,6 +10,7 @@ from torch.distributed.tensor.experimental import context_parallel from torch.nn.attention import SDPBackend, sdpa_kernel from torch.distributed.device_mesh import DeviceMesh +from transformers.distributed import DistributedConfig # model_id = "ft-hf-o-c/random-checkpoint-converted-20b" # openai = "/fsx/vb/pytorch-os-mini-final-quantized-moe-sharded" @@ -61,12 +62,15 @@ def main(): model_id, device_mesh=tp_mesh if dist.is_initialized() else None, tp_plan="auto", - enable_expert_parallel=os.environ.get("ENABLE_EXPERT_PARALLEL", "0") == "1", tp_size=tp_size, torch_dtype=torch.bfloat16, attn_implementation="eager", # key_mapping={"mlp.router": "mlp.router.router"}, + distributed_config=DistributedConfig( + enable_expert_parallel=os.environ.get("ENABLE_EXPERT_PARALLEL", "0") == "1" + ) ) + logger.info(f"Model loaded onto device mesh: {tp_mesh}") device = torch.device(f"cuda:{local_rank}") logger.info(f"Using device: {device} for non-model tensors") diff --git a/src/transformers/distributed/__init__.py b/src/transformers/distributed/__init__.py new file mode 100644 index 000000000000..8565699f7ebf --- /dev/null +++ b/src/transformers/distributed/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..utils import _LazyModule + + +_import_structure = { + "configuration_utils": [ + "DistributedConfig", + ], +} + + +if TYPE_CHECKING: + from .configuration_utils import ( + DistributedConfig, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/distributed/configuration_utils.py b/src/transformers/distributed/configuration_utils.py new file mode 100644 index 000000000000..e3087b0700ae --- /dev/null +++ b/src/transformers/distributed/configuration_utils.py @@ -0,0 +1,118 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import json +from typing import Any, Dict, Union +import os +import copy + + +@dataclass +class DistributedConfig: + """ + Base class for distributed configs + """ + + enable_expert_parallel: bool = False + # TODO: add tp_plan, pp_plan, device_mesh etc.. + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a DistributedConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + DistributedConfig: Instance of DistributedConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a9dd17d9eac3..3dc5b69d6478 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -47,7 +47,6 @@ from transformers.utils import is_torchao_available - if is_torchao_available(): from torchao.quantization import Int4WeightOnlyConfig @@ -55,6 +54,7 @@ from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig +from .distributed import DistributedConfig from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.deepspeed import _load_state_dict_into_zero3_model @@ -4283,7 +4283,7 @@ def from_pretrained( gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) tp_size = kwargs.pop("tp_size", None) - enable_expert_parallel = kwargs.pop("enable_expert_parallel", False) + distributed_config : DistributedConfig = kwargs.pop("distributed_config", None) device_mesh = kwargs.pop("device_mesh", None) trust_remote_code = kwargs.pop("trust_remote_code", None) @@ -4646,7 +4646,7 @@ def from_pretrained( device_map=device_map, ) - if enable_expert_parallel: + if distributed_config is not None and distributed_config.enable_expert_parallel: # TODO: add proper support for ep_plan independently of tp_plan if config.base_model_tp_plan is None: raise ValueError("base_model_tp_plan is required when enable_expert_parallel is True") @@ -4664,8 +4664,6 @@ def from_pretrained( # Remove None values from the tp plan config.base_model_tp_plan = {k: v for k, v in config.base_model_tp_plan.items() if v is not None} - print(f"using base_model_tp_plan: {config.base_model_tp_plan}") - with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) From 4201a2dce60ce73060349b7910cefd514be4cc07 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 14:56:51 +0000 Subject: [PATCH 109/342] move ep plan to config --- src/transformers/modeling_utils.py | 22 +++---------------- .../openai_moe/configuration_openai_moe.py | 15 +++++++++++++ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3dc5b69d6478..a9b921ea01da 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1998,10 +1998,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # - `_pp_plan["layers"][PipelineParallel.outputs]` _pp_plan = None - # Whether expert parallelism is enabled for the model. In that case we override - # `base_model_tp_plan` with expert parallel plan - _enable_expert_parallel = False - # This flag signal that the model can be used as an efficient backend in TGI and vLLM # In practice, it means that they support attention interface functions, fully pass the kwargs # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan @@ -4648,21 +4644,9 @@ def from_pretrained( if distributed_config is not None and distributed_config.enable_expert_parallel: # TODO: add proper support for ep_plan independently of tp_plan - if config.base_model_tp_plan is None: - raise ValueError("base_model_tp_plan is required when enable_expert_parallel is True") - # We apply ep on MoE layers - config.base_model_tp_plan.update({ - "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", - "layers.*.mlp.experts.down_proj": "grouped_gemm", - "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", - # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them - 'layers.*.mlp.experts': None, - 'layers.*.mlp.token_dispatcher': "gather", - "layers.*.mlp.router": "ep_router", - }) - # Remove None values from the tp plan - config.base_model_tp_plan = {k: v for k, v in config.base_model_tp_plan.items() if v is not None} + if config.base_model_ep_plan is None: + raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") + config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index a58edd581b37..f6124e7cb5b2 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -46,6 +46,21 @@ class OpenAIMoeConfig(PretrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + # "layers.*.self_attn.q_proj": "colwise", + # "layers.*.self_attn.k_proj": "colwise", + # "layers.*.self_attn.v_proj": "colwise", + # "layers.*.self_attn.o_proj": "rowwise", + # "layers.*.self_attn.sinks": "local_rowwise", + + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", + # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them + 'layers.*.mlp.token_dispatcher': "gather", + "layers.*.mlp.router": "ep_router", + } def __init__( self, From 9267f9d48c330d7236b9f41a121faa4be7ff1221 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 3 Jul 2025 15:24:15 +0000 Subject: [PATCH 110/342] . --- .../models/openai_moe/configuration_openai_moe.py | 6 +++--- .../models/openai_moe/modeling_openai_moe.py | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index f6124e7cb5b2..3b78a7381272 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -53,13 +53,13 @@ class OpenAIMoeConfig(PretrainedConfig): # "layers.*.self_attn.o_proj": "rowwise", # "layers.*.self_attn.sinks": "local_rowwise", + # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them + 'layers.*.mlp.token_dispatcher': "gather", + "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", "layers.*.mlp.experts.down_proj": "grouped_gemm", "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", - # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them - 'layers.*.mlp.token_dispatcher': "gather", - "layers.*.mlp.router": "ep_router", } def __init__( diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 2b0c1ea2d701..1f72560edb15 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -22,6 +22,7 @@ import torch from torch import nn +from torch.nn import functional as F from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -126,16 +127,18 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) return next_states, None -class TopKRouter(nn.Linear): # TODO: if i inherit from module it'll be annoying to attach hook to parent module +class TopKRouter(nn.Module): # TODO: if i inherit from module it'll be annoying to attach hook to parent module def __init__(self, config): - super().__init__(config.hidden_size, config.num_local_experts, bias=True) + super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.empty(self.num_experts)) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = super().forward(hidden_states) # (seq_len, num_experts) + router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) From b683ce758460bc7fe4f333bdf1beeb7a3adf596a Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Jul 2025 11:28:35 +0000 Subject: [PATCH 111/342] smol fixes --- src/transformers/integrations/tensor_parallel.py | 2 +- src/transformers/modeling_utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 3a9cbac56d11..2400d5254e4a 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -944,7 +944,7 @@ def shard_and_distribute_module( # Add hooks to the module if not done yet # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) - if not getattr(module_to_tp, "_is_hooked", False): # this adds gather hook to layers.*.mlp.experts and skips the rest + if not getattr(module_to_tp, "_is_hooked", False): add_tensor_parallel_hooks_to_module( model, module_to_tp, tp_plan, param_name, current_shard_plan, device_mesh, parameter_name ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a9b921ea01da..edd90536cda3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2092,10 +2092,9 @@ def post_init(self): ) # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit - # device_mesh = self._device_mesh # TODO: can we attach device_mesh to model - device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(torch.distributed.distributed_c10d._get_default_group(), device_type="cuda") # TODO: + device_mesh = self.config.device_mesh for name, module in self.named_modules(): - if not getattr(module, "_is_hooked", False): # this adds gather hook to layers.*.mlp.experts and skips the rest + if not getattr(module, "_is_hooked", False): from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module add_tensor_parallel_hooks_to_module( model=self, @@ -4648,6 +4647,8 @@ def from_pretrained( raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now + config.device_mesh = device_mesh # Used in post_init + with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) From fbd04f7b3e64761d1320053e09b9a8d9bcd6337a Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Jul 2025 11:29:07 +0000 Subject: [PATCH 112/342] bring back TP applied to attn, still gives same generations --- .../openai_moe/configuration_openai_moe.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 3b78a7381272..bf8a9bef8cad 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,11 +29,11 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - # "layers.*.self_attn.q_proj": "colwise", - # "layers.*.self_attn.k_proj": "colwise", - # "layers.*.self_attn.v_proj": "colwise", - # "layers.*.self_attn.o_proj": "rowwise", - # "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", @@ -47,11 +47,11 @@ class OpenAIMoeConfig(PretrainedConfig): "norm": (["hidden_states"], ["hidden_states"]), } base_model_ep_plan = { - # "layers.*.self_attn.q_proj": "colwise", - # "layers.*.self_attn.k_proj": "colwise", - # "layers.*.self_attn.v_proj": "colwise", - # "layers.*.self_attn.o_proj": "rowwise", - # "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them 'layers.*.mlp.token_dispatcher': "gather", From 17c5682dd5be163e9b726e170e27fe06b9c979b8 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Jul 2025 11:32:43 +0000 Subject: [PATCH 113/342] final cleanup --- hf_generate.py | 88 ---------------------------- src/transformers/generation/utils.py | 5 -- 2 files changed, 93 deletions(-) delete mode 100644 hf_generate.py diff --git a/hf_generate.py b/hf_generate.py deleted file mode 100644 index 6750d65085e2..000000000000 --- a/hf_generate.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -ENABLE_EXPERT_PARALLEL=1 TP_SIZE=8 torchrun --nproc_per_node=8 hf_generate.py -""" - -from transformers import AutoTokenizer, OpenAIMoeForCausalLM -import torch -import os -import logging -import torch.distributed as dist -from torch.distributed.tensor.experimental import context_parallel -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.distributed.device_mesh import DeviceMesh -from transformers.distributed import DistributedConfig - -# model_id = "ft-hf-o-c/random-checkpoint-converted-20b" -# openai = "/fsx/vb/pytorch-os-mini-final-quantized-moe-sharded" - -# model_id = "/scratch/pytorch-os-mini-final-quantized-moe-sharded_hf" -model_id = "pytorch-os-mini-final-quantized-moe-sharded_hf" - - -# torch.use_deterministic_algorithms(True) -torch.backends.cudnn.deterministic = True - -# Set up logging -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, -) -logger = logging.getLogger(__name__) - - -def main(): - tp_size = int(os.environ.get("TP_SIZE", 8)) - - # Initialize distributed environment - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - dist.init_process_group("nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - - mesh = torch.arange(world_size).reshape(tp_size) - world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("tp",)) - tp_mesh = world_mesh["tp"] - logger.info(f"Created DeviceMesh: {world_mesh}") - logger.info( - f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, TP: {tp_mesh.get_local_rank()}" - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id) - - # messages = [ - # {"role": "user", "content": "Who are you?"}, - # ] - # inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) - # inputs = tokenizer("Hello! How are you?", return_tensors="pt") - inputs = tokenizer("Who are you? And who made you?", return_tensors="pt") - model = OpenAIMoeForCausalLM.from_pretrained( - model_id, - device_mesh=tp_mesh if dist.is_initialized() else None, - tp_plan="auto", - tp_size=tp_size, - torch_dtype=torch.bfloat16, - attn_implementation="eager", - # key_mapping={"mlp.router": "mlp.router.router"}, - distributed_config=DistributedConfig( - enable_expert_parallel=os.environ.get("ENABLE_EXPERT_PARALLEL", "0") == "1" - ) - ) - - logger.info(f"Model loaded onto device mesh: {tp_mesh}") - device = torch.device(f"cuda:{local_rank}") - logger.info(f"Using device: {device} for non-model tensors") - model.eval() - - outputs = model.generate(**inputs.to(model.device), max_new_tokens=100, use_cache=False, do_sample=False) - outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) - print(outputs[0]) - - if dist.is_initialized(): - dist.destroy_process_group() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 23838076f043..713d57a8994d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3573,8 +3573,6 @@ def _sample( else: is_prefill = True - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained("pytorch-os-mini-final-quantized-moe-sharded_hf") while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3633,9 +3631,6 @@ def _sample( else: next_tokens = torch.argmax(next_token_scores, dim=-1) - if torch.distributed.get_rank() == 0: - print(f"Generated token: {repr(tokenizer.decode(next_tokens))}, logprob: {next_token_logits.max()}") - # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) From a6b956f4f383b147733c5559ee26abc3f2c61c78 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Jul 2025 11:36:11 +0000 Subject: [PATCH 114/342] . --- .../models/openai_moe/modeling_openai_moe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 1f72560edb15..93166f15addb 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -66,12 +66,13 @@ class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((config.num_local_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(config.num_local_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: @@ -127,7 +128,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) return next_states, None -class TopKRouter(nn.Module): # TODO: if i inherit from module it'll be annoying to attach hook to parent module +class TopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok From 27f64f204cba1623496e9ef185f56e160e74305b Mon Sep 17 00:00:00 2001 From: 3outeille Date: Fri, 4 Jul 2025 12:19:21 +0000 Subject: [PATCH 115/342] fix qkv TP by passing around device_mesh instead of recreating it --- src/transformers/modeling_utils.py | 5 ++--- .../openai_moe/configuration_openai_moe.py | 10 +++++----- .../models/openai_moe/modeling_openai_moe.py | 18 +++++++++++------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a9b921ea01da..8426420a80ba 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2052,7 +2052,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self._no_split_modules = self._no_split_modules or [] - def post_init(self): + def post_init(self, device_mesh): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). @@ -2092,8 +2092,6 @@ def post_init(self): ) # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit - # device_mesh = self._device_mesh # TODO: can we attach device_mesh to model - device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(torch.distributed.distributed_c10d._get_default_group(), device_type="cuda") # TODO: for name, module in self.named_modules(): if not getattr(module, "_is_hooked", False): # this adds gather hook to layers.*.mlp.experts and skips the rest from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module @@ -4650,6 +4648,7 @@ def from_pretrained( with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules + model_kwargs["device_mesh"] = device_mesh model = cls(config, *model_args, **model_kwargs) # Make sure to tie the weights correctly diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 3b78a7381272..452a39422323 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,11 +29,11 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - # "layers.*.self_attn.q_proj": "colwise", - # "layers.*.self_attn.k_proj": "colwise", - # "layers.*.self_attn.v_proj": "colwise", - # "layers.*.self_attn.o_proj": "rowwise", - # "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 1f72560edb15..efd8ef7c8646 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -432,8 +432,8 @@ def _init_weights(self, module): class OpenAIMoeModel(OpenAIMoePreTrainedModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] - def __init__(self, config: OpenAIMoeConfig): - super().__init__(config) + def __init__(self, config: OpenAIMoeConfig, **kwargs): + super().__init__(config, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -445,8 +445,10 @@ def __init__(self, config: OpenAIMoeConfig): self.rotary_emb = OpenAIMoeRotaryEmbedding(config=config) self.gradient_checkpointing = False + self.device_mesh = kwargs.get("device_mesh") + # Initialize weights and apply final processing - self.post_init() + self.post_init(self.device_mesh) def get_input_embeddings(self): return self.embed_tokens @@ -659,17 +661,19 @@ class OpenAIMoeForCausalLM(OpenAIMoePreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config): - super().__init__(config) - self.model = OpenAIMoeModel(config) + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = OpenAIMoeModel(config, **kwargs) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok + self.device_mesh = kwargs.get("device_mesh") + # Initialize weights and apply final processing - self.post_init() + self.post_init(self.device_mesh) def get_input_embeddings(self): return self.model.embed_tokens From 7f69dc18efb495875f4e64599ac824eb10f6dfbb Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Jul 2025 15:20:38 +0000 Subject: [PATCH 116/342] fix: (hack) only attach hooks in post_init for EP --- src/transformers/modeling_utils.py | 35 ++++++++++++++++-------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index edd90536cda3..a221d3a00492 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -47,6 +47,7 @@ from transformers.utils import is_torchao_available + if is_torchao_available(): from torchao.quantization import Int4WeightOnlyConfig @@ -2090,22 +2091,23 @@ def post_init(self): raise ValueError( f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - - # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit - device_mesh = self.config.device_mesh - for name, module in self.named_modules(): - if not getattr(module, "_is_hooked", False): - from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module - add_tensor_parallel_hooks_to_module( - model=self, - module=module, - tp_plan=self._tp_plan, - layer_name="", # TODO: make this optional? - current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), - device_mesh=device_mesh, - parameter_name=None - ) - module._is_hooked = True + + if hasattr(self.config, "attach_module_hooks") and self.config.attach_module_hooks: + # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit + device_mesh = self.config.device_mesh + for name, module in self.named_modules(): + if not getattr(module, "_is_hooked", False): + from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module + add_tensor_parallel_hooks_to_module( + model=self, + module=module, + tp_plan=self._tp_plan, + layer_name="", # TODO: make this optional? + current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), + device_mesh=device_mesh, + parameter_name=None + ) + module._is_hooked = True def dequantize(self): """ @@ -4646,6 +4648,7 @@ def from_pretrained( if config.base_model_ep_plan is None: raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now + config.attach_module_hooks = True # TODO: hack for now config.device_mesh = device_mesh # Used in post_init From 14cc0d7f3f0cd7ec9fdffaf2356cf1650d86792c Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 4 Jul 2025 16:11:08 +0000 Subject: [PATCH 117/342] fix --- src/transformers/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7b5b5589ed70..6410d55fb2f5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2306,9 +2306,7 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = ( - is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled - ) + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) @@ -2368,7 +2366,7 @@ def _inner_training_loop( if self.use_apex: model = self.accelerator.prepare(self.model) else: - if delay_optimizer_creation: + if self.is_tp_enabled: self.optimizer = self.accelerator.prepare(self.optimizer) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) @@ -3923,6 +3921,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) self.model_wrapped.save_checkpoint(output_dir) + # TODO: check why this is failing if we don't remove that # elif self.args.should_save: self._save(output_dir) From c9d52e7100924261c4f3bc6d9a106e5e4611d55c Mon Sep 17 00:00:00 2001 From: 3outeille Date: Sat, 5 Jul 2025 13:42:19 +0000 Subject: [PATCH 118/342] breaking: begin to add vocab parallel embedding --- .../integrations/tensor_parallel.py | 38 +++++++++++++++++++ .../openai_moe/configuration_openai_moe.py | 11 +++--- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 3a9cbac56d11..6d2eeba357c9 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -469,6 +469,43 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) return param +class VocabParallel(TensorParallelLayer): + """ + VocabParallel is used to shard the embedding table. + """ + + def __init__(self, use_dtensor=True): + super().__init__() + # Set the output layouts to trigger all-reduce + self.output_layouts = (Replicate(),) + self.use_local_output = True + self.use_dtensor = use_dtensor + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + print("VocabParallel _prepare_input_fn") + return inputs + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + print("VocabParallel _prepare_output_fn") + # The forward pass should produce a DTensor with Shard(0) placement + # This redistribute call will automatically perform all-reduce when going from Shard(0) -> Replicate() + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=False) + return outputs.to_local() if use_local_output else outputs + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + #TODO(3outeille): if several device_mesh dims, should be shard given TP axes + # Shard the embedding table along dim 0 (vocab dimension) + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, 0) + shard = [Shard(0)] + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + return nn.Parameter(parameter) class ColwiseParallel(TensorParallelLayer): """ @@ -807,6 +844,7 @@ class ParallelInterface(GeneralInterface): # a new instance is created (in order to locally override a given entry) _global_mapping = ( { + "vocab_parallel": VocabParallel(), "colwise": ColwiseParallel(), "rowwise": RowwiseParallel(), "colwise_rep": ColwiseParallel(output_layouts=Replicate()), diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 452a39422323..f8353cb1b0a0 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,17 +29,18 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { + "embed_tokens": "vocab_parallel", "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + # "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + # "layers.*.mlp.experts.down_proj": "local_colwise", + # "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs + # "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), From 5053bd75a9d0ba893197b045a4a3de3f37665bdd Mon Sep 17 00:00:00 2001 From: 3outeille Date: Sat, 5 Jul 2025 16:04:04 +0000 Subject: [PATCH 119/342] breaking: add masking + use partial to trigger all_reduce instead of all_gather --- .../integrations/tensor_parallel.py | 68 +++++++++++++++---- .../openai_moe/configuration_openai_moe.py | 10 +-- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 6d2eeba357c9..0088c8402f1b 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -36,7 +36,7 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - from torch.distributed.tensor import DTensor, Placement, Replicate, Shard + from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, Partial def initialize_tensor_parallelism(tp_plan, tp_size=None): @@ -474,37 +474,75 @@ class VocabParallel(TensorParallelLayer): VocabParallel is used to shard the embedding table. """ - def __init__(self, use_dtensor=True): + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + use_dtensor=True,): super().__init__() - # Set the output layouts to trigger all-reduce - self.output_layouts = (Replicate(),) - self.use_local_output = True + self.input_layouts = (input_layouts or Replicate(),) + self.desired_input_layouts = (Replicate(),) + self.output_layouts = (output_layouts or Replicate(),) + self.use_local_output = use_local_output self.use_dtensor = use_dtensor @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - print("VocabParallel _prepare_input_fn") - return inputs + rank = device_mesh.get_rank() + world_size = device_mesh.size() + vocab_size = mod.num_embeddings + local_vocab_size = vocab_size // world_size + vocab_start = rank * local_vocab_size + vocab_end = vocab_start + local_vocab_size + + input_tensor = inputs if isinstance(inputs, torch.Tensor) else inputs[0] + + input_mask = (input_tensor < vocab_start) | (input_tensor >= vocab_end) + masked_input = input_tensor.clone() - vocab_start + masked_input[input_mask] = 0 # Set invalid tokens to 0 + # Store the mask for use in _prepare_output_fn + mod._vocab_parallel_input_mask = input_mask + if not isinstance(input_tensor, DTensor): + masked_input = DTensor.from_local(masked_input, device_mesh, input_layouts, run_check=False) + + return masked_input @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - print("VocabParallel _prepare_output_fn") - # The forward pass should produce a DTensor with Shard(0) placement - # This redistribute call will automatically perform all-reduce when going from Shard(0) -> Replicate() - if outputs.placements != output_layouts: - outputs = outputs.redistribute(placements=output_layouts, async_op=False) - return outputs.to_local() if use_local_output else outputs + output_tensor = outputs if isinstance(outputs, DTensor) else outputs[0] + + # Retrieve the input mask from the module + input_mask = getattr(mod, '_vocab_parallel_input_mask', None) + if input_mask is not None: + #TODO(3outeille): double check if there is another workaround + if isinstance(output_tensor, DTensor): + local_tensor = output_tensor.to_local() + input_mask = input_mask.to(local_tensor.device) + local_tensor[input_mask] = 0.0 + else: + output_tensor[input_mask] = 0.0 + + if not isinstance(output_tensor, DTensor): + output_tensor = DTensor.from_local(output_tensor, device_mesh, output_layouts, run_check=False) + + # The forward pass should produce a DTensor with Partial() placement + # This redistribute call will automatically perform all-reduce when going from Partial() -> Replicate() + if output_tensor.placements != output_layouts: + output_tensor = output_tensor.redistribute(placements=output_layouts, async_op=False) + return output_tensor.to_local() if use_local_output else output_tensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): #TODO(3outeille): if several device_mesh dims, should be shard given TP axes # Shard the embedding table along dim 0 (vocab dimension) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, 0) - shard = [Shard(0)] + placements = [Partial()] parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + parameter = DTensor.from_local(parameter, device_mesh, placements, run_check=False) return nn.Parameter(parameter) class ColwiseParallel(TensorParallelLayer): diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index f8353cb1b0a0..289329285842 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -36,11 +36,11 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - # "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - # "layers.*.mlp.experts.down_proj": "local_colwise", - # "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - # "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj": "local_colwise", + "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs + "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), From 0e176d07d265fbb64d5e34e9e41a371813a9fd32 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 11:49:05 +0200 Subject: [PATCH 120/342] Apply suggestions from code review --- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 93166f15addb..7829739057dd 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -164,6 +164,7 @@ def __init__(self, config): self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) self.token_dispatcher = TokenDispatcher(config) + def forward(self, hidden_states): # we don't slice weight as its not compile compatible router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) From c7e694bc115e0c05c59590e9d42cc0f71d01aa77 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 7 Jul 2025 11:59:51 +0200 Subject: [PATCH 121/342] Update src/transformers/modeling_utils.py --- src/transformers/modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a221d3a00492..e9f705f6dcac 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4648,7 +4648,6 @@ def from_pretrained( if config.base_model_ep_plan is None: raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now - config.attach_module_hooks = True # TODO: hack for now config.device_mesh = device_mesh # Used in post_init From 5928038741c6f8721e6b9a7394a28e5d17e6a38b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 10:15:58 +0000 Subject: [PATCH 122/342] update modular to record router logits and disable sdpa --- .../models/openai_moe/modular_openai_moe.py | 41 ++++--------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 511e622196ef..2aa2acaf2320 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -206,39 +206,6 @@ def eager_attention_forward( attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -# def openai_flex_attention_forward( -# module: nn.Module, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# attention_mask: Optional[torch.Tensor], -# scaling: float, -# dropout: float = 0.0, -# **kwargs, -# ): -# sinks = module.sinks.view(1, -1, 1, 1).expand(-1, -1, key.shape[-2], -1) - -# def attention_sink(score, b, h, q_idx, kv_idx): -# score = torch.cat([score, sinks], dim=-1) -# return score - -# # TODO I need to remove the -1 sinks -# return flex_attention_forward( -# module, -# query, -# key, -# value, -# attention_mask, -# scaling=scaling, -# dropout=dropout, -# attention_sink=attention_sink, -# score_mod=attention_sink, -# **kwargs, -# ) - -# ALL_ATTENTION_FUNCTIONS.register("openai_flex_attention", openai_flex_attention_forward) - - class OpenAIMoeAttention(Qwen2Attention): def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -312,7 +279,13 @@ def forward( class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] - + _supports_sdpa = False + _supports_flex_attention = False + _can_record_outputs = { + "router_logits": OpenAIMoeExperts, + "hidden_states": OpenAIMoeDecoderLayer, + "attentions": OpenAIMoeAttention + } def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): From ade2cf860ea72a823f689d346d5a39424f157e0c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 10:20:36 +0000 Subject: [PATCH 123/342] cleanup --- .../models/openai_moe/modular_openai_moe.py | 77 +++---------------- 1 file changed, 9 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 2aa2acaf2320..719510fce24e 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -19,13 +19,12 @@ from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoeModelOutputWithPast, ) from ...modeling_rope_utils import dynamic_rope_update from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, logging +from ...utils import auto_docstring, can_return_tuple, logging, TransformersKwargs, OutputRecorder from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaPreTrainedModel, @@ -240,17 +239,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -262,19 +258,11 @@ def forward( **kwargs, ) hidden_states = residual + hidden_states - - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.mlp(hidden_states) + hidden_states, _ = self.mlp(hidden_states) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - if kwargs.get("output_router_logits", False): - outputs += (router_logits,) - return outputs + return hidden_states class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): @@ -282,7 +270,7 @@ class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _supports_sdpa = False _supports_flex_attention = False _can_record_outputs = { - "router_logits": OpenAIMoeExperts, + "router_logits": OutputRecorder(OpenAIMoeMLP, index=1), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention } @@ -310,7 +298,7 @@ def _init_weights(self, module): class OpenAIMoeModel(MixtralModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -320,31 +308,12 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -361,7 +330,6 @@ def forward( # It may already have been prepared by e.g. `generate` if not isinstance(causal_mask_mapping := attention_mask, dict): - # Prepare mask arguments mask_kwargs = { "config": self.config, "input_embeds": inputs_embeds, @@ -369,27 +337,16 @@ def forward( "cache_position": cache_position, "past_key_values": past_key_values, } - # Create the masks causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), } hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, @@ -401,26 +358,10 @@ def forward( position_embeddings=position_embeddings, **flash_attn_kwargs, ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, ) From 7ce28ada4e22ccd6608d5c5d30470d4c6ab72a8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 10:22:52 +0000 Subject: [PATCH 124/342] fix modeling code as well! --- .../models/openai_moe/modeling_openai_moe.py | 121 +++++------------- 1 file changed, 34 insertions(+), 87 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 2678e4f125f6..b002784d05fa 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -18,7 +18,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch from torch import nn @@ -33,13 +33,11 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from ...utils import OutputRecorder, TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs from .configuration_openai_moe import OpenAIMoeConfig -logger = logging.get_logger(__name__) - - @use_kernel_forward_from_hub("RMSNorm") class OpenAIMoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -123,7 +121,6 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -@use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() @@ -236,9 +233,14 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + # scale the logits to prevent overflows + logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + sinks = torch.exp(sinks - logits_max) + unnormalized_scores = torch.exp(attn_weights - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -274,12 +276,12 @@ def __init__(self, config: OpenAIMoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -332,17 +334,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -354,19 +353,11 @@ def forward( **kwargs, ) hidden_states = residual + hidden_states - - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.mlp(hidden_states) + hidden_states, _ = self.mlp(hidden_states) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - if kwargs.get("output_router_logits", False): - outputs += (router_logits,) - return outputs + return hidden_states @auto_docstring @@ -377,13 +368,19 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True - _supports_sdpa = True + _supports_sdpa = False _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(OpenAIMoeMLP, index=1), + "hidden_states": OpenAIMoeDecoderLayer, + "attentions": OpenAIMoeAttention, + } _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_flex_attention = False def _init_weights(self, module): std = self.config.initializer_range @@ -432,41 +429,22 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -483,7 +461,6 @@ def forward( # It may already have been prepared by e.g. `generate` if not isinstance(causal_mask_mapping := attention_mask, dict): - # Prepare mask arguments mask_kwargs = { "config": self.config, "input_embeds": inputs_embeds, @@ -491,27 +468,16 @@ def forward( "cache_position": cache_position, "past_key_values": past_key_values, } - # Create the masks causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), } hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, @@ -523,34 +489,15 @@ def forward( position_embeddings=position_embeddings, **flash_attn_kwargs, ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, ) -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None, @@ -558,7 +505,7 @@ def load_balancing_loss_func( r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. @@ -674,7 +621,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -683,7 +630,7 @@ def forward( output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): From 0967a5404a27917ae84b09924d52823823456e1d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 10:50:40 +0000 Subject: [PATCH 125/342] update modular use kernel --- src/transformers/models/openai_moe/modular_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 719510fce24e..2d8691f49ba9 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -111,6 +111,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states +@use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() From ea3b4f5ce380aeee0b870aa018c170d9e2816147 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:51:42 +0200 Subject: [PATCH 126/342] Update src/transformers/models/openai_moe/modeling_openai_moe.py --- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index b002784d05fa..9d24bd9984ef 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -121,6 +121,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states +@use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() From 3051a73975f7e8d607725ca2ac0f98e8ca66e5e6 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 7 Jul 2025 14:54:44 +0000 Subject: [PATCH 127/342] fix generation (but not exactly matching yet) --- src/transformers/integrations/tensor_parallel.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 0088c8402f1b..d6a46d830505 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -525,10 +525,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me output_tensor[input_mask] = 0.0 if not isinstance(output_tensor, DTensor): - output_tensor = DTensor.from_local(output_tensor, device_mesh, output_layouts, run_check=False) + output_tensor = DTensor.from_local(output_tensor, device_mesh, [Shard(0)], run_check=False) - # The forward pass should produce a DTensor with Partial() placement - # This redistribute call will automatically perform all-reduce when going from Partial() -> Replicate() if output_tensor.placements != output_layouts: output_tensor = output_tensor.redistribute(placements=output_layouts, async_op=False) return output_tensor.to_local() if use_local_output else output_tensor @@ -537,7 +535,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, #TODO(3outeille): if several device_mesh dims, should be shard given TP axes # Shard the embedding table along dim 0 (vocab dimension) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, 0) - placements = [Partial()] + placements = [Shard(0)] parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() From bd52d615f41fb039aa65b6718e99fc8a402b98ce Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:03:59 +0200 Subject: [PATCH 128/342] Update src/transformers/models/openai_moe/modular_openai_moe.py --- src/transformers/models/openai_moe/modular_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 2d8691f49ba9..6fc2f2570e89 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -202,7 +202,7 @@ def eager_attention_forward( scores = unnormalized_scores / normalizer attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks + attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights From 8bc3475f99196f2e1d3666ebb24a28f928248e46 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:04:38 +0200 Subject: [PATCH 129/342] Update src/transformers/models/openai_moe/modeling_openai_moe.py --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 9d24bd9984ef..16fe462a0cc7 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -242,7 +242,7 @@ def eager_attention_forward( scores = unnormalized_scores / normalizer attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks + attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights From d15a5ff99d0793b94cad21e9138c7f085a6e1d74 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 17:56:55 +0200 Subject: [PATCH 130/342] Update src/transformers/models/openai_moe/modeling_openai_moe.py --- src/transformers/models/openai_moe/modeling_openai_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 7829739057dd..7783eecc57d7 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -320,7 +320,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - raise ValueError(f"Attention implementation {self.config._attn_implementation} doesn't support sinks") attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( From 98ab5956fb55746896fb420e9f4b2ce044debc7d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 7 Jul 2025 16:02:09 +0000 Subject: [PATCH 131/342] removed for now --- src/transformers/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e125ae91a7d8..540729889e27 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3937,10 +3937,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa # remove the dummy state_dict remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) self.model_wrapped.save_checkpoint(output_dir) - - # TODO: check why this is failing if we don't remove that - # elif self.args.should_save: - self._save(output_dir) + + elif self.args.should_save: + self._save(output_dir) # Push to the Hub when `save_model` is called by the user. if self.args.push_to_hub and not _internal_call: From eec466e6d00414f5f461d6c649a65513735b5d04 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 7 Jul 2025 16:05:10 +0000 Subject: [PATCH 132/342] fix --- src/transformers/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 540729889e27..add87f3a8e18 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2294,9 +2294,7 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = ( - is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled - ) + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) From d67c5cfc13898ecebf5587babfd7c5bd5d686125 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 7 Jul 2025 18:06:22 +0200 Subject: [PATCH 133/342] Update src/transformers/modeling_utils.py --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e9f705f6dcac..6578a42ace56 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2092,7 +2092,7 @@ def post_init(self): f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - if hasattr(self.config, "attach_module_hooks") and self.config.attach_module_hooks: + if is_torch_greater_or_equal("2.5") and _torch_distributed_available: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh for name, module in self.named_modules(): From a85634c1f275a23eb6f90a91015d32fe006b72af Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 16:24:56 +0000 Subject: [PATCH 134/342] make sure modular matches! --- .../models/openai_moe/modular_openai_moe.py | 73 ++++++++++++------- 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 6fc2f2570e89..9a3dc80ab0e5 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -52,8 +52,8 @@ def forward(self, hidden_states): class OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) @@ -70,16 +70,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. Args: - hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) routing_weights (torch.Tensor): (batch_size * token_num, top_k) Returns: torch.Tensor """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[0] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute( + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( 2, 1, 0 ) expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() @@ -100,42 +103,62 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig ) # (num_tokens, hidden_dim) weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + next_states = next_states.view(batch_size, -1, self.hidden_size) else: - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + self.down_proj_bias[..., None, :] - next_states = next_states.view(-1, self.hidden_size) - return next_states + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) + return next_states, None +class TopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.empty(self.num_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + return router_scores, router_indices + +class TokenDispatcher(nn.Module): + # this module is important to add EP hook + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + def forward(self, routed_out, routing_weights): + # routed_out is (num_experts, batch_size, seq_len, hidden_size) + routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts + routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) + return routed_out @use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_local_experts = config.num_local_experts + self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) - self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=True) + self.token_dispatcher = TokenDispatcher(config) def forward(self, hidden_states): # we don't slice weight as its not compile compatible - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) - routed_out = self.experts(hidden_states, router_indices, router_top_value) - if self.training: - output_states = routed_out.view(batch_size, -1, self.hidden_dim) - else: - routed_out = routed_out.view(self.num_local_experts, -1, self.hidden_dim) * router_scores[..., None] - output_states = routed_out.view(self.num_local_experts, batch_size, -1, self.hidden_dim).sum(dim=0) - return output_states, router_scores + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + hidden_states = self.token_dispatcher(routed_out, router_scores) + return hidden_states, router_scores class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding): From a90469637d53c27515e37bb18cbbab865862a8c7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 7 Jul 2025 16:49:08 +0000 Subject: [PATCH 135/342] last nit --- src/transformers/models/openai_moe/modular_openai_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 9a3dc80ab0e5..4df8e552873d 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -16,6 +16,7 @@ import torch from torch import nn +from torch.nn import functional as F from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask From 4e1b2db2e7c83776dca44e9d3a89eca9748e8df9 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 7 Jul 2025 18:00:28 +0000 Subject: [PATCH 136/342] add eager_attention_forward patch --- .../models/openai_moe/modeling_openai_moe.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index efd8ef7c8646..c3b02906314d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -257,10 +257,15 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights[..., :-1], value_states) # ignore the sinks + # scale the logits to prevent overflows + logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + sinks = torch.exp(sinks - logits_max) + unnormalized_scores = torch.exp(attn_weights - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights From e7aaf9e67a3aeaf32841be3c3333d04892ffa41a Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 7 Jul 2025 18:01:00 +0000 Subject: [PATCH 137/342] finally matching logits. no need to do input masking as embedding would be stored in `_MaskPartial` which handles it --- .../integrations/tensor_parallel.py | 67 ++++++------------- 1 file changed, 19 insertions(+), 48 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d6a46d830505..39883d0d3b5f 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -470,9 +470,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, return param class VocabParallel(TensorParallelLayer): - """ - VocabParallel is used to shard the embedding table. - """ def __init__( self, @@ -480,59 +477,16 @@ def __init__( input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, use_local_output: bool = True, - use_dtensor=True,): + use_dtensor=True, + ): super().__init__() self.input_layouts = (input_layouts or Replicate(),) self.desired_input_layouts = (Replicate(),) self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output self.use_dtensor = use_dtensor - - @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - rank = device_mesh.get_rank() - world_size = device_mesh.size() - vocab_size = mod.num_embeddings - local_vocab_size = vocab_size // world_size - vocab_start = rank * local_vocab_size - vocab_end = vocab_start + local_vocab_size - - input_tensor = inputs if isinstance(inputs, torch.Tensor) else inputs[0] - - input_mask = (input_tensor < vocab_start) | (input_tensor >= vocab_end) - masked_input = input_tensor.clone() - vocab_start - masked_input[input_mask] = 0 # Set invalid tokens to 0 - # Store the mask for use in _prepare_output_fn - mod._vocab_parallel_input_mask = input_mask - if not isinstance(input_tensor, DTensor): - masked_input = DTensor.from_local(masked_input, device_mesh, input_layouts, run_check=False) - return masked_input - - @staticmethod - def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - output_tensor = outputs if isinstance(outputs, DTensor) else outputs[0] - - # Retrieve the input mask from the module - input_mask = getattr(mod, '_vocab_parallel_input_mask', None) - if input_mask is not None: - #TODO(3outeille): double check if there is another workaround - if isinstance(output_tensor, DTensor): - local_tensor = output_tensor.to_local() - input_mask = input_mask.to(local_tensor.device) - local_tensor[input_mask] = 0.0 - else: - output_tensor[input_mask] = 0.0 - - if not isinstance(output_tensor, DTensor): - output_tensor = DTensor.from_local(output_tensor, device_mesh, [Shard(0)], run_check=False) - - if output_tensor.placements != output_layouts: - output_tensor = output_tensor.redistribute(placements=output_layouts, async_op=False) - return output_tensor.to_local() if use_local_output else output_tensor - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - #TODO(3outeille): if several device_mesh dims, should be shard given TP axes # Shard the embedding table along dim 0 (vocab dimension) parameter = get_tensor_shard(param, empty_param, device_mesh, rank, 0) placements = [Shard(0)] @@ -543,6 +497,23 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, parameter = DTensor.from_local(parameter, device_mesh, placements, run_check=False) return nn.Parameter(parameter) + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + #NOTE(3outeille): no need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + return outputs.to_local() if use_local_output else outputs + class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. From a3cdac2bfa5a6fed7a2280077306f16b2d8a7cbd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 11:51:39 +0000 Subject: [PATCH 138/342] fix non TP cases for now --- .../integrations/tensor_parallel.py | 2 +- src/transformers/modeling_utils.py | 2 +- .../openai_moe/configuration_openai_moe.py | 3 +- .../models/openai_moe/modeling_openai_moe.py | 59 +++++-------------- 4 files changed, 18 insertions(+), 48 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 18c15fa37353..b736703d8da7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -812,7 +812,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, class RouterParallel(TensorParallelLayer): """ - Applies Expert Parallelism to MoE router + Allows to reshape the router scores to support running expert parallel. """ def __init__(self, *args, **kwargs): self.args = args diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ffbe79c7d1d7..95575e310000 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2151,7 +2151,7 @@ def post_init(self): f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - if is_torch_greater_or_equal("2.5") and _torch_distributed_available: + if is_torch_greater_or_equal("2.5") and _torch_distributed_available and self.config.device_mesh is not None: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh for name, module in self.named_modules(): diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index bf8a9bef8cad..dbde38eb8c7b 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -54,7 +54,8 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.sinks": "local_rowwise", # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them - 'layers.*.mlp.token_dispatcher': "gather", + "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + # 'layers.*.mlp.token_dispatcher': "gather", "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index fe885c732235..9c0df024e5c2 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -28,14 +28,13 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import OutputRecorder, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs, OutputRecorder from .configuration_openai_moe import OpenAIMoeConfig @@ -115,16 +114,19 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) next_states = next_states.view(batch_size, -1, self.hidden_size) + routing_weights = torch.ones_like(next_states) else: hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + gate, up = gate_up.chunk(2, dim=-1) glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - return next_states, None + next_states = next_states * routing_weights.view(num_experts, batch_size, -1)[...,None] + next_states = next_states.sum(dim=0) + return next_states, routing_weights class TopKRouter(nn.Module): def __init__(self, config): @@ -143,33 +145,17 @@ def forward(self, hidden_states): router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) return router_scores, router_indices -class TokenDispatcher(nn.Module): - # this module is important to add EP hook - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - - def forward(self, routed_out, routing_weights): - # routed_out is (num_experts, batch_size, seq_len, hidden_size) - routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts - routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) - return routed_out - @use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) - self.token_dispatcher = TokenDispatcher(config) def forward(self, hidden_states): - # we don't slice weight as its not compile compatible router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func - hidden_states = self.token_dispatcher(routed_out, router_scores) - return hidden_states, router_scores + routed_out, router_weights = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + return routed_out, router_scores class OpenAIMoeRotaryEmbedding(nn.Module): @@ -305,8 +291,8 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -371,7 +357,6 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -492,6 +477,7 @@ def forward( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids } causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), @@ -507,12 +493,10 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **flash_attn_kwargs, + **kwargs, ) hidden_states = self.norm(hidden_states) return MoeModelOutputWithPast( @@ -555,7 +539,7 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device).transpose(0,1) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) @@ -572,7 +556,6 @@ def load_balancing_loss_func( else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] @@ -650,8 +633,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, @@ -680,16 +661,6 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -697,8 +668,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, cache_position=cache_position, **kwargs, From b04fde8d565e507ad6d4db088a53777ae8c67174 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 11:57:32 +0000 Subject: [PATCH 139/342] default dict not working as well as I thought --- src/transformers/utils/generic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 5326d48d748b..b3615128d558 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -1030,7 +1030,10 @@ def wrapped_forward(*args, **kwargs): if not isinstance(output, tuple): collected_outputs[key] += (output,) elif output[index] is not None: - collected_outputs[key] += (output[index],) + if key not in collected_outputs: + collected_outputs[key] =(output[index],) + else: + collected_outputs[key] += (output[index],) return output return wrapped_forward From 55d870385d389d5d9b9ecc5ae5191faa0950ddc2 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 8 Jul 2025 12:15:28 +0000 Subject: [PATCH 140/342] make api more generic to handle lm_head --- .../integrations/tensor_parallel.py | 41 ++++++++++++------- .../openai_moe/configuration_openai_moe.py | 3 +- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 39883d0d3b5f..6b05080ee6f3 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -471,11 +471,17 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, class VocabParallel(TensorParallelLayer): + """ + VocabParallel is a tensor parallel layer that shards the embedding table along the last dimension. + No need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) + """ + def __init__( self, *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, + weight_dim_sharding: int = 0, use_local_output: bool = True, use_dtensor=True, ): @@ -485,33 +491,37 @@ def __init__( self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output self.use_dtensor = use_dtensor - - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - # Shard the embedding table along dim 0 (vocab dimension) - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, 0) - placements = [Shard(0)] - parameter = parameter.to(param_casting_dtype) - if to_contiguous: - parameter = parameter.contiguous() - if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, placements, run_check=False) - return nn.Parameter(parameter) + self.weight_dim_sharding = weight_dim_sharding @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - #NOTE(3outeille): no need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) return input_tensor + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # Shard the embedding table along dim 0 (vocab dimension) + if param_type == "bias": + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + placements = [Shard(-1)] + else: + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, self.weight_dim_sharding) + placements = [Shard(self.weight_dim_sharding)] + parameter = parameter.to(param_casting_dtype) + if to_contiguous: + parameter = parameter.contiguous() + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, placements, run_check=False) + return nn.Parameter(parameter) + @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): if outputs.placements != output_layouts: - outputs = outputs.redistribute(placements=output_layouts, async_op=True) + outputs = outputs.redistribute(placements=output_layouts, async_op=False) return outputs.to_local() if use_local_output else outputs class ColwiseParallel(TensorParallelLayer): @@ -851,7 +861,8 @@ class ParallelInterface(GeneralInterface): # a new instance is created (in order to locally override a given entry) _global_mapping = ( { - "vocab_parallel": VocabParallel(), + "vocab_parallel_rowwise": VocabParallel(weight_dim_sharding=0), + "vocab_parallel_colwise": VocabParallel(weight_dim_sharding=-2, output_layouts=Replicate()), "colwise": ColwiseParallel(), "rowwise": RowwiseParallel(), "colwise_rep": ColwiseParallel(output_layouts=Replicate()), diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 289329285842..47a83a4aa659 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,7 +29,7 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - "embed_tokens": "vocab_parallel", + "embed_tokens": "vocab_parallel_rowwise", "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", @@ -41,6 +41,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + "lm_head": "vocab_parallel_colwise" } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), From 30b599a853fb28d4f5fc43e4888e223264e4ac31 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 13:03:19 +0000 Subject: [PATCH 141/342] enable flash --- .../modeling_flash_attention_utils.py | 19 +++++++++++++++++-- src/transformers/modeling_utils.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 1 + 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 13da327dab00..20da32ac30a1 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -26,10 +26,15 @@ is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torch_npu_available, + is_kernels_available, logging, ) +if is_kernels_available(): + from kernels import get_kernel + + logger = logging.get_logger(__name__) flash_attn_func = None @@ -404,7 +409,7 @@ def fa_peft_integration_check( flash_241 = is_flash_attn_greater_or_equal("2.4.1") deterministic_g = None - +kernel = None def _flash_attention_forward( query_states: torch.Tensor, @@ -473,6 +478,16 @@ def _flash_attention_forward( _pad_input = pad_input_fa2 _unpad_input = unpad_input_fa2 _is_fa3 = False + elif "kernel" in attn_implementation: + repo_id = attn_implementation.replace("kernel_", "").replace("_", "/") + global kernel + if kernel is None: + kernel = get_kernel(repo_id) + _flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func") + _flash_attn_func = getattr(kernel, "flash_attn_func") + _pad_input = pad_input_fa2 + _unpad_input = unpad_input_fa2 + _is_fa3 = True if not use_top_left_mask: causal = is_causal @@ -485,7 +500,7 @@ def _flash_attention_forward( _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - + flash_kwargs["s_aux"] = kwargs["s_aux"] if _is_fa3: if dropout > 0.0: logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 95575e310000..7ad5df75803f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2309,7 +2309,7 @@ def _autoset_attn_implementation( try: kernel = get_kernel(repo_id) ALL_ATTENTION_FUNCTIONS.register( - f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name) + f"kernel_{repo_id.replace('/', '_')}", ALL_ATTENTION_FUNCTIONS["flash_attention_2"] # we need our extra layer to support kwargs ) config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}" except FileNotFoundError as e: diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 9c0df024e5c2..3d8b4089304d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -321,6 +321,7 @@ def forward( dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, # main diff with Llama + s_aux=self.sinks, **kwargs, ) From 1decf01d53e7236f84d79d3568f3de0f867cd0dd Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 8 Jul 2025 13:30:32 +0000 Subject: [PATCH 142/342] clean device_mesh --- .../integrations/tensor_parallel.py | 2 +- src/transformers/modeling_utils.py | 3 +-- .../models/openai_moe/modeling_openai_moe.py | 18 +++++++----------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index a04b3aed9bb3..3c6b6691ed0e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -35,7 +35,7 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, Partial + from torch.distributed.tensor import DTensor, Placement, Replicate, Shard def initialize_tensor_parallelism(tp_plan, tp_size=None): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1a3cdb814600..95575e310000 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2103,7 +2103,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): self._no_split_modules = self._no_split_modules or [] _CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only - def post_init(self, device_mesh): + def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). @@ -4837,7 +4837,6 @@ def from_pretrained( with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules - model_kwargs["device_mesh"] = device_mesh model = cls(config, *model_args, **model_kwargs) # Make sure to tie the weights correctly diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 085d21da4b91..b9b30d3a48e4 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -417,8 +417,8 @@ def _init_weights(self, module): class OpenAIMoeModel(OpenAIMoePreTrainedModel): _no_split_modules = ["OpenAIMoeDecoderLayer"] - def __init__(self, config: OpenAIMoeConfig, **kwargs): - super().__init__(config, **kwargs) + def __init__(self, config: OpenAIMoeConfig): + super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -430,10 +430,8 @@ def __init__(self, config: OpenAIMoeConfig, **kwargs): self.rotary_emb = OpenAIMoeRotaryEmbedding(config=config) self.gradient_checkpointing = False - self.device_mesh = kwargs.get("device_mesh") - # Initialize weights and apply final processing - self.post_init(self.device_mesh) + self.post_init() def get_input_embeddings(self): return self.embed_tokens @@ -594,19 +592,17 @@ class OpenAIMoeForCausalLM(OpenAIMoePreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.model = OpenAIMoeModel(config, **kwargs) + def __init__(self, config): + super().__init__(config) + self.model = OpenAIMoeModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok - self.device_mesh = kwargs.get("device_mesh") - # Initialize weights and apply final processing - self.post_init(self.device_mesh) + self.post_init() def get_input_embeddings(self): return self.model.embed_tokens From 3c484daa4ab82a1de9ad165e186220361832d7e0 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 8 Jul 2025 15:58:37 +0000 Subject: [PATCH 143/342] dont transpose otherwise index out of bound when computing weighting_output --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index b9b30d3a48e4..740dbcda74a9 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -142,7 +142,7 @@ def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") From 10ef90d5eaa61d4fd42e294394a28332b8035e06 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 17:40:42 +0000 Subject: [PATCH 144/342] trtrying to fix --- src/transformers/integrations/flex_attention.py | 10 +++++++--- src/transformers/masking_utils.py | 5 ++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 9abff30e3961..1de7d2769a15 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -198,9 +198,6 @@ def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod if offsets is not None: - q_offset = offsets[0] - kv_offset = offsets[1] - def mask_mod(batch_idx, head_idx, q_idx, kv_idx): offset_q = q_idx + q_offset offset_kv = kv_idx + kv_offset @@ -241,6 +238,7 @@ def flex_attention_forward( scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, + s_aux: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if head_mask is not None: @@ -271,6 +269,12 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): score = score + score_mask[batch_idx][0][q_idx][kv_idx] if head_mask is not None: score = score + head_mask[batch_idx][head_idx][0][0] + if s_aux is not None: + logits_max = torch.max(score, dim=-1, keepdim=True).values + sinks = torch.exp(s_aux - logits_max) + unnormalized_scores = torch.exp(score - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer return score enable_gqa = True diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8d5aab9f1342..eabd5af82f65 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -46,7 +46,7 @@ def and_masks(*mask_functions: list[Callable]) -> Callable: def and_mask(batch_idx, head_idx, q_idx, kv_idx): result = q_idx.new_ones((), dtype=torch.bool) for mask in mask_functions: - result = result & mask(batch_idx, head_idx, q_idx, kv_idx) + result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device) return result return and_mask @@ -60,7 +60,7 @@ def or_masks(*mask_functions: list[Callable]) -> Callable: def or_mask(batch_idx, head_idx, q_idx, kv_idx): result = q_idx.new_zeros((), dtype=torch.bool) for mask in mask_functions: - result = result | mask(batch_idx, head_idx, q_idx, kv_idx) + result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device) return result return or_mask @@ -141,7 +141,6 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. """ - def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset) From dbb3c3b5b72b368475f2b5a10201460287e7188f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 17:41:35 +0000 Subject: [PATCH 145/342] fix EP training, there was a pb, thanks bouteille --- src/transformers/integrations/tensor_parallel.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b736703d8da7..6844db7168e0 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -831,7 +831,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs - router_scores = router_scores[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] + router_scores = router_scores[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 9c0df024e5c2..349fd343110a 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -88,7 +88,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[0] + num_experts = routing_weights.shape[1] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): @@ -124,7 +124,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - next_states = next_states * routing_weights.view(num_experts, batch_size, -1)[...,None] + next_states = next_states * routing_weights.transpose(0,1).view(num_experts, batch_size, -1)[...,None] next_states = next_states.sum(dim=0) return next_states, routing_weights @@ -142,7 +142,7 @@ def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -311,7 +311,7 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - + attn_output, attn_weights = attention_interface( self, query_states, @@ -375,7 +375,7 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): config_class = OpenAIMoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["OpenAIMoeDecoderLayer"] + _no_split_modules = ["OpenAIMoeDecoderLayer", "OpenAIMoeAttention"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = False From c71777d1d4c7aebad52b5c37d749a1bbca6b6705 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 8 Jul 2025 18:10:23 +0000 Subject: [PATCH 146/342] fix again --- src/transformers/integrations/tensor_parallel.py | 1 + src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 6844db7168e0..bbf3a376b461 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -832,6 +832,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs router_scores = router_scores[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] + router_indices = router_indices[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 349fd343110a..d55a01cfb516 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -92,7 +92,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( + expert_mask = torch.nn.functional.one_hot(router_indices % num_experts, num_classes=num_experts).permute( 2, 1, 0 ) expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() @@ -539,7 +539,7 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device).transpose(0,1) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) From 6b402a960aaa942a557045b3cddfa3933ea70d0e Mon Sep 17 00:00:00 2001 From: 3outeille Date: Wed, 9 Jul 2025 09:54:46 +0000 Subject: [PATCH 147/342] no need for vocab parallel during inference --- .../models/openai_moe/configuration_openai_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 8fa68c30fc0a..591c91914b12 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -29,7 +29,7 @@ class OpenAIMoeConfig(PretrainedConfig): # Default tensor parallel plan for base model `OpenaiModel` # a bit special, but this seems to work alright base_model_tp_plan = { - "embed_tokens": "vocab_parallel_rowwise", + # "embed_tokens": "vocab_parallel_rowwise", "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", @@ -41,7 +41,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output - "lm_head": "vocab_parallel_colwise" + # "lm_head": "vocab_parallel_colwise" } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), From be973245b510cfcab3ee071a7b2cfd194ecfca2f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 9 Jul 2025 14:13:47 +0000 Subject: [PATCH 148/342] some fixes --- src/transformers/integrations/tensor_parallel.py | 2 +- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b736703d8da7..1219ca68f78e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -832,7 +832,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs router_scores = router_scores[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - return router_scores, router_indices + return router_scores, router_indices % num_local_experts def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # TODO: i'd like for this to be the default diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 3d8b4089304d..429bdbc71826 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -141,7 +141,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) return router_scores, router_indices From ac0d14015fd0f56a94a983dcdfff17fc325ed6c3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 9 Jul 2025 14:14:56 +0000 Subject: [PATCH 149/342] IDK why some of the fixes did not get in --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 429bdbc71826..883fd681d4ab 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -88,7 +88,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[0] + num_experts = routing_weights.shape[1] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): @@ -142,7 +142,7 @@ def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") From 5f5054ae5b9989feaa0dcb959fe3738ba6ef5176 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 9 Jul 2025 14:17:39 +0000 Subject: [PATCH 150/342] merges suckes a lot --- src/transformers/integrations/tensor_parallel.py | 1 + src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 1219ca68f78e..a27d77434556 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -832,6 +832,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs router_scores = router_scores[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] + router_indices = router_indices[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] return router_scores, router_indices % num_local_experts def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 883fd681d4ab..1dd99292e937 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -124,7 +124,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - next_states = next_states * routing_weights.view(num_experts, batch_size, -1)[...,None] + next_states = next_states * routing_weights.transopose(0,1).view(num_experts, batch_size, -1)[...,None] next_states = next_states.sum(dim=0) return next_states, routing_weights @@ -376,7 +376,7 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): config_class = OpenAIMoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["OpenAIMoeDecoderLayer"] + _no_split_modules = ["OpenAIMoeDecoderLayer", "OpenAIMoeAttention"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = False From 58ee998e48882b3a6c5e7e14c8c8aa133e76f0a3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 9 Jul 2025 15:51:02 +0000 Subject: [PATCH 151/342] fix shape and index issues we had --- .../integrations/tensor_parallel.py | 8 ++++--- .../models/openai_moe/modeling_openai_moe.py | 21 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index a27d77434556..710be54e37ac 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -831,9 +831,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs - router_scores = router_scores[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - router_indices = router_indices[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - return router_scores, router_indices % num_local_experts + if not mod.training: + router_scores = router_scores[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] + # We tell the expert that thos indices are no his to handle + router_indices = router_indices.masked_fill(router_indices // ep_size == mod.num_experts, 0 ) % num_local_experts + return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # TODO: i'd like for this to be the default diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 1dd99292e937..60c58f284c1a 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -92,16 +92,15 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( - 2, 1, 0 - ) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2,1,0) expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted: + for expert_idx in expert_hitted[0]: with torch.no_grad(): - idx, top_x = torch.where( - expert_mask[expert_idx][0] + expert_idx, token_idx = torch.where( + expert_mask[expert_idx] ) # idx: top-1/top-2 indicator, top_x: token indices - current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + current_state = hidden_states[token_idx] # (num_tokens, hidden_dim) gate_up = ( current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] ) # (num_tokens, 2 * interm_dim) @@ -111,8 +110,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig out = ( gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) - next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + weighted_output = out * routing_weights[token_idx, expert_idx, None] # (num_tokens, hidden_dim) + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)[0]) next_states = next_states.view(batch_size, -1, self.hidden_size) routing_weights = torch.ones_like(next_states) else: @@ -124,7 +123,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - next_states = next_states * routing_weights.transopose(0,1).view(num_experts, batch_size, -1)[...,None] + next_states = next_states * routing_weights.transpose(0,1).view(num_experts, batch_size, -1)[...,None] next_states = next_states.sum(dim=0) return next_states, routing_weights @@ -540,7 +539,7 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device).transpose(0,1) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) From 9a73b3f8da9225d667794c2bafae0f80dde4c90c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 9 Jul 2025 17:15:31 +0000 Subject: [PATCH 152/342] fix exploding memory issues in EP, need to test if non EP is still valid, and TP as well --- .../integrations/tensor_parallel.py | 37 +++++++++++++++---- src/transformers/modeling_utils.py | 4 +- .../openai_moe/configuration_openai_moe.py | 2 - .../models/openai_moe/modeling_openai_moe.py | 17 +++++---- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 710be54e37ac..d9cfcb1bac97 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -22,6 +22,7 @@ import torch import torch.distributed as dist +from torch.autograd import Function from torch import nn from ..utils import is_torch_greater_or_equal, logging @@ -82,12 +83,12 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None): tp_device = torch.device(device_type, index) # Silence output for non-primary ranks - if index is not None and index > 0: - import sys - - sys.stdout = open(os.devnull, "w") - sys.stderr = open(os.devnull, "w") - + # if index is not None and index > 0: + # import sys + # + # sys.stdout = open(os.devnull, "w") + # sys.stderr = open(os.devnull, "w") + # device_map = tp_device tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) @@ -407,6 +408,25 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), ) +class AllReduce(Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: + # Make a copy so we don’t modify the original in-place + out = tensor.clone() + # Sum-reduce across all ranks + dist.all_reduce(out, op=dist.ReduceOp.SUM) + return out + # + # @staticmethod + # def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + # # The gradient itself needs to be summed across ranks, too + # grad = grad_output.clone() + # dist.all_reduce(grad, op=dist.ReduceOp.SUM) + # return grad + +def all_reduce(tensor: torch.Tensor) -> torch.Tensor: + return AllReduce.apply(tensor) + # use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice # you name it. Whatever you want to do that is a bit unconventional, you need local tensors @@ -438,10 +458,11 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # this op cannot be async, otherwise it completely breaks the outputs of models if isinstance(outputs, torch.Tensor): - torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.SUM, async_op=False) + dist.all_reduce(outputs, op = dist.ReduceOp.SUM) else: # TODO: we assume we want to allreduce first element of tuple - torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) # TODO: rename GatherParallel to ReduceParallel or something + # all_reduce(outputs[0]) + dist.all_reduce(outputs[0], op = dist.ReduceOp.SUM) return outputs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7ad5df75803f..92b4dbeb6210 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4826,10 +4826,10 @@ def from_pretrained( torch_dtype=torch_dtype, device_map=device_map, ) - + # if distributed_config is not None and distributed_config.enable_expert_parallel: # TODO: add proper support for ep_plan independently of tp_plan - if config.base_model_ep_plan is None: + if getattr(config, "base_model_ep_plan", None)is None: raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index dbde38eb8c7b..96f4fdd4ea47 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -53,9 +53,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - # TODO: i shouldn't have to do the above, but when removing it, it doesnt partition them "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output - # 'layers.*.mlp.token_dispatcher': "gather", "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 60c58f284c1a..e0789b53df28 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -97,9 +97,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hitted[0]: with torch.no_grad(): - expert_idx, token_idx = torch.where( + routing_idx, token_idx = torch.where( expert_mask[expert_idx] ) # idx: top-1/top-2 indicator, top_x: token indices + print(expert_idx) current_state = hidden_states[token_idx] # (num_tokens, hidden_dim) gate_up = ( current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] @@ -110,8 +111,9 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig out = ( gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[token_idx, expert_idx, None] # (num_tokens, hidden_dim) - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)[0]) + weighted_output = out * routing_weights[token_idx, routing_idx, None] # (num_tokens, hidden_dim) + print(token_idx.shape, weighted_output.shape) + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) routing_weights = torch.ones_like(next_states) else: @@ -572,8 +574,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -581,8 +583,9 @@ def load_balancing_loss_func( router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) - - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + + rank = routing_weights.shape[1]*int(routing_weights.device.index) + overall_loss = torch.sum(tokens_per_expert[:,rank: rank +routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts From beb2ea6b183267a6249414e35d2f5e66e9ca3591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 10 Jul 2025 01:01:04 +0000 Subject: [PATCH 153/342] back --- .../models/openai_moe/modeling_openai_moe.py | 30 +++++++------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index d55a01cfb516..4adb2780260e 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -235,25 +235,17 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.reshape(1, -1, 1, 1).expand( - query.shape[0], -1, query.shape[-2], -1 - ) # TODO make sure the sink is like a new token attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # scale the logits to prevent overflows - logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - sinks = torch.exp(sinks - logits_max) - unnormalized_scores = torch.exp(attn_weights - logits_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer - - attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -350,7 +342,7 @@ def forward( position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states + # residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, @@ -362,11 +354,11 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, _ = self.mlp(hidden_states) - hidden_states = residual + hidden_states + # hidden_states = residual + hidden_states + # residual = hidden_states + # hidden_states = self.post_attention_layernorm(hidden_states) + # hidden_states, _ = self.mlp(hidden_states) + # hidden_states = residual + hidden_states return hidden_states @@ -487,7 +479,7 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for decoder_layer in self.layers[:1]: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], From 108bff09ffcfcf11023542c8583a8b8cc1c4001c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 10 Jul 2025 01:07:12 +0000 Subject: [PATCH 154/342] fix --- src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 4adb2780260e..a51725b00cd3 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -241,9 +241,12 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + sinks = module.sinks.view(1, -1, 1, 1).expand(query.shape[0], query.shape[1], query.shape[2], 1) + attn_weights = torch.cat([attn_weights, sinks], dim=-1) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights_real = attn_weights[..., :-1] + attn_output = torch.matmul(attn_weights_real, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights From 546efee628796df3a66f5ae301313c721f15b805 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 10 Jul 2025 03:59:04 +0000 Subject: [PATCH 155/342] fix --- .../models/openai_moe/modeling_openai_moe.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index a51725b00cd3..e5eab5d3c2a5 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -235,20 +235,26 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + query.shape[0], -1, query.shape[-2], -1 + ) # TODO make sure the sink is like a new token attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - sinks = module.sinks.view(1, -1, 1, 1).expand(query.shape[0], query.shape[1], query.shape[2], 1) - attn_weights = torch.cat([attn_weights, sinks], dim=-1) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_weights_real = attn_weights[..., :-1] - attn_output = torch.matmul(attn_weights_real, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() + # scale the logits to prevent overflows + logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + logits_max = torch.where(logits_max < 1e9, torch.tensor(0.0, device=logits_max.device), logits_max) # when logits_max is <1e9, just use 0 + sinks = torch.exp(sinks - logits_max) + unnormalized_scores = torch.exp(attn_weights - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks + attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -345,7 +351,7 @@ def forward( position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - # residual = hidden_states + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, _ = self.self_attn( hidden_states=hidden_states, @@ -357,11 +363,11 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) - # hidden_states = residual + hidden_states - # residual = hidden_states - # hidden_states = self.post_attention_layernorm(hidden_states) - # hidden_states, _ = self.mlp(hidden_states) - # hidden_states = residual + hidden_states + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) + hidden_states = residual + hidden_states return hidden_states @@ -482,7 +488,7 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[:1]: + for decoder_layer in self.layers: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], From 7b077aeec48a1e69127a74bad0424212e58980db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 10 Jul 2025 04:45:22 +0000 Subject: [PATCH 156/342] nice --- .../models/openai_moe/modeling_openai_moe.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e5eab5d3c2a5..dda07459ff18 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -235,23 +235,15 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.reshape(1, -1, 1, 1).expand( - query.shape[0], -1, query.shape[-2], -1 - ) # TODO make sure the sink is like a new token attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # scale the logits to prevent overflows - logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - logits_max = torch.where(logits_max < 1e9, torch.tensor(0.0, device=logits_max.device), logits_max) # when logits_max is <1e9, just use 0 - sinks = torch.exp(sinks - logits_max) - unnormalized_scores = torch.exp(attn_weights - logits_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer - + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() From ebcae9a905661ba351a137328091eea195853e08 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 10 Jul 2025 15:52:16 +0000 Subject: [PATCH 157/342] where i am at --- .../integrations/tensor_parallel.py | 83 ++++++++++++++----- src/transformers/modeling_utils.py | 23 ++--- .../openai_moe/configuration_openai_moe.py | 12 ++- .../models/openai_moe/modeling_openai_moe.py | 71 ++++++++-------- 4 files changed, 117 insertions(+), 72 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d9cfcb1bac97..96d568349293 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -83,12 +83,12 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None): tp_device = torch.device(device_type, index) # Silence output for non-primary ranks - # if index is not None and index > 0: - # import sys - # - # sys.stdout = open(os.devnull, "w") - # sys.stderr = open(os.devnull, "w") - # + if index is not None and index > 0: + import sys + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + device_map = tp_device tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) @@ -408,6 +408,7 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), ) + class AllReduce(Function): @staticmethod def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: @@ -416,6 +417,7 @@ def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: # Sum-reduce across all ranks dist.all_reduce(out, op=dist.ReduceOp.SUM) return out + # # @staticmethod # def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: @@ -424,6 +426,7 @@ def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: # dist.all_reduce(grad, op=dist.ReduceOp.SUM) # return grad + def all_reduce(tensor: torch.Tensor) -> torch.Tensor: return AllReduce.apply(tensor) @@ -450,19 +453,20 @@ def __init__( @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + mod.expert_parallel_group = device_mesh.get_group() if inputs and isinstance(inputs[0], DTensor): inputs = inputs[0].to_local() return inputs @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + if hasattr(mod, "kernel_layer_name"): # kernels usually handle this themselves + return outputs # this op cannot be async, otherwise it completely breaks the outputs of models if isinstance(outputs, torch.Tensor): - dist.all_reduce(outputs, op = dist.ReduceOp.SUM) + dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) else: - # TODO: we assume we want to allreduce first element of tuple - # all_reduce(outputs[0]) - dist.all_reduce(outputs[0], op = dist.ReduceOp.SUM) + dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) return outputs @@ -812,10 +816,12 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) + class GroupedGemmParallel(TensorParallelLayer): """ Applies Expert Parallelism to MoE experts by loading the correct experts on each device. """ + def __init__(self): super().__init__() self.use_dtensor = False @@ -824,17 +830,21 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, ep_rank = rank global_num_experts = empty_param.shape[0] if global_num_experts % device_mesh.size() != 0: - raise ValueError(f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0") + raise ValueError( + f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0" + ) local_num_experts = global_num_experts // device_mesh.size() - param = param[ep_rank*local_num_experts:(ep_rank+1)*local_num_experts].to(param_casting_dtype) + param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype) if to_contiguous: param = param.contiguous() return param + class RouterParallel(TensorParallelLayer): """ - Allows to reshape the router scores to support running expert parallel. + Allows to reshape the router scores to support running expert parallel. """ + def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs @@ -846,16 +856,48 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ if isinstance(input_tensor, DTensor): raise NotImplementedError("RouterParallel does not support DTensor input for now") return input_tensor - + @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + """ + Imagine if you had 4 tokens, top_k = 4, and 128experts. + With EP = 8. + Imagine router_indices being: + [ 52, 42, 119, 67], + [102, 89, 61, 40], + [ 82, 103, 4, 34], + [ 93, 23, 109, 11], + + then you can map which rank should be getting which values + + [3, 2, 7, 4], + [6, 5, 3, 2], + [5, 6, 0, 2], + [5, 1, 6, 0], + + Thus for say rank 0, you fill with 0 the index tensor + + [ 0, 0, 0, 0], + [ 0, 0, 0, 0], + [ 0, 0, 4, 0], + [ 0, 0, 0, 11], + + This works well. For another rank you need to make sure you round to num_local_expert + because the next operation will one hot encode the router index vector. + + This allows us to know directly which local expert is hit. + Similarly the scores are indexed with something created form + router_indices. + + The kinda naive training loop that we use for device_map "auto" uses a similar logic. + Here we are just making each rank believe that he is alone, and he computes his part of the hiddenstates. + """ ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs - if not mod.training: - router_scores = router_scores[:, ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - # We tell the expert that thos indices are no his to handle - router_indices = router_indices.masked_fill(router_indices // ep_size == mod.num_experts, 0 ) % num_local_experts + router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] + router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1) + router_indices = (router_indices + 1) % (num_local_experts + 1) return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): @@ -864,7 +906,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, if to_contiguous: param = param.contiguous() return param - def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: # TODO: need an abstract Parallel class that is different from TensorParallelLayer @@ -992,7 +1033,7 @@ def __init__(self): def shard_and_distribute_module( model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh -): # TODO: rename to shard_and_distribute_param +): # TODO: rename to shard_and_distribute_param r""" Main uses cases: - column / rowise parallelism, you just shard all the weights of the layer (weight and bias) @@ -1004,7 +1045,7 @@ def shard_and_distribute_module( """ param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name tp_plan = model._tp_plan - module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules? + module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules? rank = int(rank) current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 92b4dbeb6210..7262e56e1496 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -775,7 +775,7 @@ def _load_state_dict_into_meta_model( file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) for param_name, empty_param in state_dict.items(): - if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling + if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling continue # we need to use serialized_param_name as file pointer is untouched @@ -2150,21 +2150,22 @@ def post_init(self): raise ValueError( f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - + if is_torch_greater_or_equal("2.5") and _torch_distributed_available and self.config.device_mesh is not None: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh for name, module in self.named_modules(): if not getattr(module, "_is_hooked", False): from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module + add_tensor_parallel_hooks_to_module( model=self, module=module, tp_plan=self._tp_plan, - layer_name="", # TODO: make this optional? + layer_name="", # TODO: make this optional? current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), device_mesh=device_mesh, - parameter_name=None + parameter_name=None, ) module._is_hooked = True @@ -2309,7 +2310,8 @@ def _autoset_attn_implementation( try: kernel = get_kernel(repo_id) ALL_ATTENTION_FUNCTIONS.register( - f"kernel_{repo_id.replace('/', '_')}", ALL_ATTENTION_FUNCTIONS["flash_attention_2"] # we need our extra layer to support kwargs + f"kernel_{repo_id.replace('/', '_')}", + ALL_ATTENTION_FUNCTIONS["flash_attention_2"], # we need our extra layer to support kwargs ) config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}" except FileNotFoundError as e: @@ -4463,7 +4465,7 @@ def from_pretrained( gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) tp_size = kwargs.pop("tp_size", None) - distributed_config : DistributedConfig = kwargs.pop("distributed_config", None) + distributed_config: DistributedConfig = kwargs.pop("distributed_config", None) device_mesh = kwargs.pop("device_mesh", None) trust_remote_code = kwargs.pop("trust_remote_code", None) use_kernels = kwargs.pop("use_kernels", False) @@ -4829,11 +4831,11 @@ def from_pretrained( # if distributed_config is not None and distributed_config.enable_expert_parallel: # TODO: add proper support for ep_plan independently of tp_plan - if getattr(config, "base_model_ep_plan", None)is None: + if getattr(config, "base_model_ep_plan", None) is None: raise ValueError("base_model_ep_plan is required when enable_expert_parallel is True") - config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now + config.base_model_tp_plan = config.base_model_ep_plan # TODO: hack for now - config.device_mesh = device_mesh # Used in post_init + config.device_mesh = device_mesh # Used in post_init with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules @@ -4941,8 +4943,9 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: from kernels import Device, kernelize + from kernels.layer import Mode - kernelize(model, device=Device(type=model.device.type)) + kernelize(model, mode=Mode.TRAINING, device=Device(type=model.device.type)) # If it is a model with generation capabilities, attempt to load generation files (generation config, # custom generate function) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 96f4fdd4ea47..d901c8f6a26b 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -34,12 +34,11 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs + # "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), @@ -52,19 +51,18 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - - "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + "layers.*.mlp": "gather", "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", "layers.*.mlp.experts.down_proj": "grouped_gemm", - "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", } def __init__( self, num_hidden_layers: int = 36, - num_local_experts: int = 128, #TODO: rename to num_experts otherwise confusing with EP + num_local_experts: int = 128, # TODO: rename to num_experts otherwise confusing with EP vocab_size: int = 201088, hidden_size: int = 2880, intermediate_size: int = 2880, diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e0789b53df28..31f6db834f7d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -82,40 +82,34 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig Args: hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) - routing_weights (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) Returns: torch.Tensor """ batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[1] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2,1,0) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence lenght to get which experts + # are hit this time around expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted[0]: + for expert_idx in expert_hitted[:]: with torch.no_grad(): - routing_idx, token_idx = torch.where( - expert_mask[expert_idx] - ) # idx: top-1/top-2 indicator, top_x: token indices - print(expert_idx) - current_state = hidden_states[token_idx] # (num_tokens, hidden_dim) - gate_up = ( - current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - ) # (num_tokens, 2 * interm_dim) - gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) - glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) - gated_output = (up + 1) * glu # (num_tokens, interm_dim) - out = ( - gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[token_idx, routing_idx, None] # (num_tokens, hidden_dim) - print(token_idx.shape, weighted_output.shape) + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + expert_idx -= 1 # because the index given by expert mask is off by one + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gate, up = gate_up.chunk(2, dim=-1) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out[0] * routing_weights[token_idx - 1, expert_idx, None] next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) - routing_weights = torch.ones_like(next_states) else: hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) @@ -124,11 +118,12 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - next_states = next_states * routing_weights.transpose(0,1).view(num_experts, batch_size, -1)[...,None] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] next_states = next_states.sum(dim=0) return next_states, routing_weights + class TopKRouter(nn.Module): def __init__(self, config): super().__init__() @@ -137,15 +132,16 @@ def __init__(self, config): self.hidden_dim = config.hidden_size self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) self.bias = nn.Parameter(torch.empty(self.num_experts)) - + def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) return router_scores, router_indices + @use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): def __init__(self, config): @@ -154,8 +150,10 @@ def __init__(self, config): self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out, router_weights = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + routed_out, router_weights = self.experts( + hidden_states, router_indices=router_indices, routing_weights=router_scores + ) return routed_out, router_scores @@ -245,13 +243,16 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # scale the logits to prevent overflows + # # scale the logits to prevent overflows logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values sinks = torch.exp(sinks - logits_max) unnormalized_scores = torch.exp(attn_weights - logits_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks scores = unnormalized_scores / normalizer + # sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + # combined_logits = torch.cat([attn_weights, sinks], dim=-1) + # scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() @@ -479,7 +480,7 @@ def forward( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "position_ids": position_ids + "position_ids": position_ids, } causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), @@ -583,9 +584,11 @@ def load_balancing_loss_func( router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) - - rank = routing_weights.shape[1]*int(routing_weights.device.index) - overall_loss = torch.sum(tokens_per_expert[:,rank: rank +routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)) + + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts From 528b3c859d3ba14e2aab8e3c8d62c084e4f02c40 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 10 Jul 2025 15:57:22 +0000 Subject: [PATCH 158/342] Bro this works --- src/transformers/integrations/tensor_parallel.py | 5 +++-- src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 96d568349293..c491ba41cdb7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -896,8 +896,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me num_local_experts = mod.num_experts // ep_size router_scores, router_indices = outputs router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] - router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1) - router_indices = (router_indices + 1) % (num_local_experts + 1) + router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) + router_indices = (router_indices) % (num_local_experts) + # router_indices = (router_indices + 1) % (num_local_experts + 1) return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 31f6db834f7d..3eb0f862883e 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -92,7 +92,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) expert_mask = expert_mask.permute(2, 1, 0) # we sum on the top_k and on the sequence lenght to get which experts # are hit this time around @@ -101,13 +101,12 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig with torch.no_grad(): _, token_idx = torch.where(expert_mask[expert_idx[0]]) current_state = hidden_states[token_idx] - expert_idx -= 1 # because the index given by expert mask is off by one gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gate, up = gate_up.chunk(2, dim=-1) glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out[0] * routing_weights[token_idx - 1, expert_idx, None] + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) else: From 297e47e2eebf23a0110d1b8804888fc33d59442e Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 11 Jul 2025 14:25:15 +0200 Subject: [PATCH 159/342] Update src/transformers/integrations/tensor_parallel.py --- src/transformers/integrations/tensor_parallel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d720c96f3feb..828c118b7e52 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -520,6 +520,8 @@ class VocabParallel(TensorParallelLayer): """ VocabParallel is a tensor parallel layer that shards the embedding table along the last dimension. No need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) + + This is useful if you want to train with long sequence length! """ def __init__( From 3d25cf756131c2d875ac23d62617808fcb550d15 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 11 Jul 2025 12:47:14 +0000 Subject: [PATCH 160/342] cleanups --- .../integrations/flex_attention.py | 7 ++++-- .../integrations/tensor_parallel.py | 21 ++++++++--------- .../openai_moe/configuration_openai_moe.py | 4 +--- .../models/openai_moe/modeling_openai_moe.py | 23 +++++++++++-------- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 583fcee3f7d0..6fa9aa830636 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -198,6 +198,9 @@ def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod if offsets is not None: + q_offset = offsets[0].to(device) + kv_offset = offsets[1].to(device) + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): offset_q = q_idx + q_offset offset_kv = kv_idx + kv_offset @@ -258,7 +261,7 @@ def flex_attention_forward( block_mask = attention_mask else: score_mask = attention_mask - + if score_mask is not None: score_mask = score_mask[:, :, :, : key.shape[-2]] @@ -274,7 +277,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): sinks = torch.exp(s_aux - logits_max) unnormalized_scores = torch.exp(score - logits_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer + score = unnormalized_scores / normalizer return score enable_gqa = True diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index c491ba41cdb7..48cde8ff663b 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -418,13 +418,12 @@ def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(out, op=dist.ReduceOp.SUM) return out - # - # @staticmethod - # def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - # # The gradient itself needs to be summed across ranks, too - # grad = grad_output.clone() - # dist.all_reduce(grad, op=dist.ReduceOp.SUM) - # return grad + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + # The gradient itself needs to be summed across ranks, too + grad = grad_output.clone() + dist.all_reduce(grad, op=dist.ReduceOp.SUM) + return grad def all_reduce(tensor: torch.Tensor) -> torch.Tensor: @@ -462,11 +461,10 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): if hasattr(mod, "kernel_layer_name"): # kernels usually handle this themselves return outputs - # this op cannot be async, otherwise it completely breaks the outputs of models if isinstance(outputs, torch.Tensor): - dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) + all_reduce(outputs) else: - dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) + all_reduce(outputs[0]) return outputs @@ -897,8 +895,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me router_scores, router_indices = outputs router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) - router_indices = (router_indices) % (num_local_experts) - # router_indices = (router_indices + 1) % (num_local_experts + 1) + router_indices = router_indices % num_local_experts return router_scores, router_indices def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index d901c8f6a26b..e1d1a5725d6e 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -26,8 +26,6 @@ class OpenAIMoeConfig(PretrainedConfig): """ model_type = "openai_moe" - # Default tensor parallel plan for base model `OpenaiModel` - # a bit special, but this seems to work alright base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", @@ -38,7 +36,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - # "layers.*.mlp.experts": "gather", # TODO: same, this should mean i want to allreduce output + "layers.*.mlp": "gather", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 3eb0f862883e..508c900c3f38 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -242,16 +242,18 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # # scale the logits to prevent overflows - logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - sinks = torch.exp(sinks - logits_max) - unnormalized_scores = torch.exp(attn_weights - logits_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer - - # sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - # combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] + # TODO: check wether both produce the same results or not! + # scale the logits to prevent overflows + # logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + # sinks = torch.exp(sinks - logits_max) + # unnormalized_scores = torch.exp(attn_weights - logits_max) + # normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + # scores = unnormalized_scores / normalizer + + # TODO we are going with this one to fix gradients becoming nans + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() @@ -585,6 +587,7 @@ def load_balancing_loss_func( ) rank = routing_weights.shape[1] * int(routing_weights.device.index) + # TODO not 100% this one is correct, will fix in another PR overall_loss = torch.sum( tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) ) From 29454d2e82958964eec6c9b68802c921ecb60273 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 11 Jul 2025 13:18:00 +0000 Subject: [PATCH 161/342] yups that was breaking --- .../integrations/tensor_parallel.py | 25 ++----------------- .../openai_moe/configuration_openai_moe.py | 2 +- .../models/openai_moe/modeling_openai_moe.py | 16 ++++++------ 3 files changed, 11 insertions(+), 32 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 48cde8ff663b..fe7bd65380ac 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -409,27 +409,6 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: ) -class AllReduce(Function): - @staticmethod - def forward(ctx, tensor: torch.Tensor) -> torch.Tensor: - # Make a copy so we don’t modify the original in-place - out = tensor.clone() - # Sum-reduce across all ranks - dist.all_reduce(out, op=dist.ReduceOp.SUM) - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - # The gradient itself needs to be summed across ranks, too - grad = grad_output.clone() - dist.all_reduce(grad, op=dist.ReduceOp.SUM) - return grad - - -def all_reduce(tensor: torch.Tensor) -> torch.Tensor: - return AllReduce.apply(tensor) - - # use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice # you name it. Whatever you want to do that is a bit unconventional, you need local tensors class GatherParallel(TensorParallelLayer): @@ -462,9 +441,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me if hasattr(mod, "kernel_layer_name"): # kernels usually handle this themselves return outputs if isinstance(outputs, torch.Tensor): - all_reduce(outputs) + dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) else: - all_reduce(outputs[0]) + dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) return outputs diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index e1d1a5725d6e..f7b3e7567799 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -36,7 +36,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - "layers.*.mlp": "gather", + # "layers.*.mlp": "gather", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 508c900c3f38..252fa9a6b357 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -244,16 +244,16 @@ def eager_attention_forward( # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows - # logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - # sinks = torch.exp(sinks - logits_max) - # unnormalized_scores = torch.exp(attn_weights - logits_max) - # normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - # scores = unnormalized_scores / normalizer + logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + sinks = torch.exp(sinks - logits_max) + unnormalized_scores = torch.exp(attn_weights - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer # TODO we are going with this one to fix gradients becoming nans - sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) - scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] + # sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + # combined_logits = torch.cat([attn_weights, sinks], dim=-1) + # scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks attn_output = attn_output.transpose(1, 2).contiguous() From 15c85e0e32efa42b17e0544c0f3dcb29fc8cb878 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:05:03 +0200 Subject: [PATCH 162/342] Update src/transformers/models/openai_moe/modeling_openai_moe.py --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e30befb3f2b4..450b943d6ab5 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -239,7 +239,7 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values From ad0fc38fbe900b5fb5227567fb71dea6e68decda Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 17 Jul 2025 12:26:10 +0000 Subject: [PATCH 163/342] gather on experts and not mlp --- src/transformers/models/openai_moe/configuration_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index d5d3e41d63c4..03f553219605 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -50,7 +50,7 @@ class OpenAIMoeConfig(PretrainedConfig): "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp": "gather", + "layers.*.mlp.experts": "gather", "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", From 4fb73451fb929c56a34ff2f1921bc2551cd9b185 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 18 Jul 2025 07:50:49 +0000 Subject: [PATCH 164/342] add changes for latest convert branch --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 ++-- src/transformers/models/openai_moe/modular_openai_moe.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 450b943d6ab5..859dadcaf708 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -102,7 +102,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig _, token_idx = torch.where(expert_mask[expert_idx[0]]) current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up.chunk(2, dim=-1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] @@ -113,7 +113,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 4df8e552873d..a0efcb4806a3 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -96,7 +96,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate_up = ( current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] ) # (num_tokens, 2 * interm_dim) - gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) gated_output = (up + 1) * glu # (num_tokens, interm_dim) out = ( @@ -109,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] From 968238cab9886e33c9a7e607f77648ca4c80690c Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 18 Jul 2025 09:10:45 +0000 Subject: [PATCH 165/342] adds options to get output_router_logits from config --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 859dadcaf708..e9cc409eed8d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -666,7 +666,9 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, From 4bc55572211bd8b666ed80ca80e8aad7d86cb4ad Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Tue, 22 Jul 2025 14:29:28 +0200 Subject: [PATCH 166/342] bring chat temlate + special tokens back into the script. --- .../convert_openai_weights_to_hf.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 56a7ada7aa67..f9c34883598e 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -301,11 +301,38 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): <|start|>{% if m['role'] == 'tool' %}{{ m['name'] }}{% else %}{{ m['role'] }}{% if m.get('name') %}:{{ m['name'] }}{% endif %}{% endif %}{% if m.get('recipient') and m['recipient'] != 'all' %} to={{ m['recipient'] }}{% endif %}{% if m.get('channel') %}<|channel|>{{ m['channel'] }}{% endif %}{% if m.get('content_type') %} {{ m['content_type'] }}{% endif %}<|message|> {%- endmacro -%} +{# Add CoT dropping logic -------------------------------------------- #} +{%- set last_final_idx = None -%} +{%- for idx in range(messages|length) -%} + {%- set m = messages[idx] -%} + {%- if m['role'] == 'assistant' and m.get('channel') == 'final' -%} + {%- set last_final_idx = idx -%} + {%- endif -%} +{%- endfor -%} +{%- set last_user_idx = None -%} +{%- if last_final_idx is not none -%} + {%- for idx in range(last_final_idx - 1, -1, -1) -%} + {%- if messages[idx]['role'] == 'user' -%} + {%- set last_user_idx = idx -%} + {%- break -%} + {%- endif -%} + {%- endfor -%} +{%- endif -%} + {# --------------------------------------------------------------------- - Render complete history + Render complete history (with CoT dropping) #} -{%- for message in messages -%} - {{- harmony_header(message) -}}{{ message['content'] }}{%- if message['role'] == 'assistant' -%}<|return|>{%- else -%}<|end|>{%- endif -%} +{%- for idx in range(messages|length) -%} + {%- set message = messages[idx] -%} + {%- set skip = false -%} + {%- if last_final_idx is not none and idx < last_final_idx and (last_user_idx is none or idx > last_user_idx) -%} + {%- if message['role'] == 'assistant' and message.get('channel') != 'final' -%} + {%- set skip = true -%} + {%- endif -%} + {%- endif -%} + {%- if not skip -%} + {{- harmony_header(message) -}}{{ message['content'] }}{%- if message['role'] == 'assistant' and message.get('channel') == 'final' -%}<|return|>{%- else -%}<|end|>{%- endif -%} + {%- endif -%} {%- endfor -%} {# --------------------------------------------------------------------- From 07bd34d44ea1b31eba18b92dd65a102ab324b87d Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 14:59:06 +0000 Subject: [PATCH 167/342] initial commmit --- src/transformers/modeling_utils.py | 2 +- .../convert_openai_weights_to_hf.py | 118 +++++++++++++++--- 2 files changed, 104 insertions(+), 16 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce514799d067..31ddc23d2576 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2151,7 +2151,7 @@ def post_init(self): f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - if is_torch_greater_or_equal("2.5") and _torch_distributed_available and self.config.device_mesh is not None: + if is_torch_greater_or_equal("2.5") and _torch_distributed_available and hasattr(self.config, "device_mesh") and self.config.device_mesh is not None: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh for name, module in self.named_modules(): diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index f9c34883598e..f93bcb0dd5df 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -74,12 +74,63 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) return output_dict +FP4_VALUES = [ + +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + import math + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + # to match for now existing implementation + return out.to(torch.float8_e5m2) + def write_model( model_path, input_base_path, safe_serialization=True, instruct=False, + unpack=True, ): os.makedirs(model_path, exist_ok=True) bos_token_id = 128000 @@ -99,14 +150,14 @@ def write_model( } config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, **original_config) - + print(config) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} for file in list(os.listdir(input_base_path)): if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) - print("Converting ..") + print("Converting ..", unpack) all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) @@ -132,7 +183,28 @@ def write_model( state_dict[k_key] = k.contiguous().to(torch.bfloat16) state_dict[v_key] = v.contiguous().to(torch.bfloat16) elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key: - state_dict[new_key] = final_[key].permute(0, 2, 1).contiguous() # einsum in orignal, I use bmm + if unpack: + if "scales" in new_key: + continue + elif "blocks" in new_key: + # deal with packed weights + blocks = final_[key] + scales = final_[key.replace("blocks", "scales")] + new_key = new_key.replace(".blocks","") + unpacked_tensors = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16) + unpacked_tensors = unpacked_tensors.permute(0, 2, 1).contiguous() # einsum in orignal, I use bmm + state_dict[new_key] = unpacked_tensors + else: + raise(f"Unidentified {key}, please double check the state dict") + else: + if "scales" in new_key: + new_key = new_key.replace(".scales", "_scales") + state_dict[new_key] = final_[key].contiguous() + elif "blocks" in new_key: + new_key = new_key.replace(".blocks", "_blocks") + state_dict[new_key] = final_[key].contiguous() + else: + raise(f"Unidentified {key}, please double check the state dict") else: weight = final_[key] if not re.search("norm", new_key): @@ -142,16 +214,24 @@ def write_model( del final_ gc.collect() - print("Loading the checkpoint in a OpenAIMoe model") - with torch.device("meta"): - model = OpenAIMoeForCausalLM(config) - model.load_state_dict(state_dict, strict=True, assign=True) - print("Checkpoint loaded successfully.") - del config._name_or_path - - print("Saving the model") - model.save_pretrained(model_path, safe_serialization=safe_serialization) - del state_dict, model + if unpack: + print("Loading the checkpoint in a OpenAIMoe model for unpacked format") + with torch.device("meta"): + model = OpenAIMoeForCausalLM(config) + model.load_state_dict(state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del config._name_or_path + + print("Saving the model") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + else: + print("Saving the checkpoint in packed format") + from safetensors.torch import save_file + config.quantization_config = {"quant_method": "mxfp4"} + config.save_pretrained(model_path) + save_file(state_dict, os.path.join(model_path, "model.safetensors")) # Safety check: reload the converted model gc.collect() @@ -362,12 +442,12 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", - default="/fsx/arthur/oai", + default="/fsx/mohamed/oai-hf/tests/20b", help="Location of LLaMA weights, which contains tokenizer.model and model folders", ) parser.add_argument( "--output_dir", - default="/fsx/arthur/oai_hf", + default="/fsx/mohamed/oai-hf/tests/20b_converted_packed", help="Location to write HF model and tokenizer", ) parser.add_argument( @@ -385,12 +465,20 @@ def main(): action="store_true", help="Whether the model is an instruct model", ) + + parser.add_argument( + "--unpack", + action="store_true", + help="Whether to unpack the model or keep the scales as in the original format. Defaults to True if not specified.", + ) + args = parser.parse_args() write_model( model_path=args.output_dir, input_base_path=args.input_dir, safe_serialization=args.safe_serialization, instruct=args.instruct, + unpack=args.unpack, ) write_tokenizer( From b7987d2e433afed34c039189cc3c787846a09b9f Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 15:07:10 +0000 Subject: [PATCH 168/342] update --- .../models/openai_moe/convert_openai_weights_to_hf.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index f93bcb0dd5df..f502a997b167 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -229,7 +229,14 @@ def write_model( else: print("Saving the checkpoint in packed format") from safetensors.torch import save_file - config.quantization_config = {"quant_method": "mxfp4"} + config.quantization_config = { + "quant_method": "mxfp4", + "modules_to_not_convert":[ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head" + ]} config.save_pretrained(model_path) save_file(state_dict, os.path.join(model_path, "model.safetensors")) From 2c0fd4d36bf7014479723d566ea4686bfe57be00 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 15:45:23 +0000 Subject: [PATCH 169/342] working with shards --- .../convert_openai_weights_to_hf.py | 40 ++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index f502a997b167..1b7a312956ed 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -228,7 +228,6 @@ def write_model( else: print("Saving the checkpoint in packed format") - from safetensors.torch import save_file config.quantization_config = { "quant_method": "mxfp4", "modules_to_not_convert":[ @@ -238,13 +237,16 @@ def write_model( "lm_head" ]} config.save_pretrained(model_path) - save_file(state_dict, os.path.join(model_path, "model.safetensors")) + save_sharded_model(state_dict, model_path) + del state_dict # Safety check: reload the converted model gc.collect() - print("Reloading the model to check if it's saved correctly.") - OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") + # TODO: remove when mxfp4 pr is merged + if unpack: + print("Reloading the model to check if it's saved correctly.") + OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") # generation config if instruct: @@ -260,6 +262,33 @@ def write_model( generation_config.save_pretrained(model_path) +def save_sharded_model(state_dict, model_path): + import math + from safetensors.torch import save_file + + max_shard_size = 4800000000 # 4.8 GB + os.makedirs(model_path, exist_ok=True) + shard_size_counter = 0 + shard_id = 0 + shard_state_dict = {} + total_sharded_dict = {} + for key in state_dict.keys(): + size = state_dict[key].numel()*state_dict[key].element_size() + if shard_size_counter + size > max_shard_size: + total_sharded_dict[shard_id] = shard_state_dict + shard_id += 1 + shard_size_counter = 0 + shard_state_dict = {} + shard_state_dict[key] = state_dict[key] + shard_size_counter += size + total_sharded_dict[shard_id] = shard_state_dict + num_shards = len(total_sharded_dict) - 1 + for shard_id, shard_state_dict in total_sharded_dict.items(): + save_file( + shard_state_dict, + os.path.join(model_path, f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors") + ) + # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode def bytes_to_unicode(): """ @@ -359,7 +388,6 @@ def __init__( **kwargs, ) - def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): # Updated Harmony chat template chat_template = """{# Harmony chat template -------------------------------------------------- From 1d03f3ac527ba1a500e866a7dd5fd802ecfcfcb4 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 15:54:32 +0000 Subject: [PATCH 170/342] add model.safetensors.index.json --- .../openai_moe/convert_openai_weights_to_hf.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 1b7a312956ed..6756ae06df11 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -263,7 +263,6 @@ def write_model( def save_sharded_model(state_dict, model_path): - import math from safetensors.torch import save_file max_shard_size = 4800000000 # 4.8 GB @@ -272,8 +271,13 @@ def save_sharded_model(state_dict, model_path): shard_id = 0 shard_state_dict = {} total_sharded_dict = {} + safetensors_index = {} + safetensors_index["metadata"] = {"total_size": 0} + safetensors_index["weight_map"] = {} for key in state_dict.keys(): size = state_dict[key].numel()*state_dict[key].element_size() + safetensors_index["metadata"]["total_size"] += size + safetensors_index["weight_map"][key] = shard_id if shard_size_counter + size > max_shard_size: total_sharded_dict[shard_id] = shard_state_dict shard_id += 1 @@ -288,6 +292,14 @@ def save_sharded_model(state_dict, model_path): shard_state_dict, os.path.join(model_path, f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors") ) + create_safetensors_index(safetensors_index, num_shards, model_path) + +def create_safetensors_index(safetensors_index, num_shards, model_path): + for shard_id in range(num_shards): + safetensors_index["weight_map"][f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors"] = shard_id + with open(os.path.join(model_path, "model.safetensors.index.json"), "w") as f: + json.dump(safetensors_index, f) + # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode def bytes_to_unicode(): @@ -477,12 +489,12 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", - default="/fsx/mohamed/oai-hf/tests/20b", + default="/fsx/mohamed/oai-hf/tests/120b", help="Location of LLaMA weights, which contains tokenizer.model and model folders", ) parser.add_argument( "--output_dir", - default="/fsx/mohamed/oai-hf/tests/20b_converted_packed", + default="/fsx/mohamed/oai-hf/tests/120b_converted_packed", help="Location to write HF model and tokenizer", ) parser.add_argument( From 40e379d1e6c470ce105938bf845ee4867e9f5952 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 15:58:17 +0000 Subject: [PATCH 171/342] fix --- .../models/openai_moe/convert_openai_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 6756ae06df11..714ab2cdd330 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -296,7 +296,7 @@ def save_sharded_model(state_dict, model_path): def create_safetensors_index(safetensors_index, num_shards, model_path): for shard_id in range(num_shards): - safetensors_index["weight_map"][f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors"] = shard_id + safetensors_index["weight_map"][shard_id] = f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors" with open(os.path.join(model_path, "model.safetensors.index.json"), "w") as f: json.dump(safetensors_index, f) From b68aa6b4237b48072d04f56f489ae9b558294177 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 16:03:53 +0000 Subject: [PATCH 172/342] fix --- .../models/openai_moe/convert_openai_weights_to_hf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 714ab2cdd330..ddf398048f04 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -295,8 +295,9 @@ def save_sharded_model(state_dict, model_path): create_safetensors_index(safetensors_index, num_shards, model_path) def create_safetensors_index(safetensors_index, num_shards, model_path): - for shard_id in range(num_shards): - safetensors_index["weight_map"][shard_id] = f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors" + for key in safetensors_index["weight_map"].keys(): + shard_id = safetensors_index["weight_map"][key] + safetensors_index["weight_map"][key] = f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors" with open(os.path.join(model_path, "model.safetensors.index.json"), "w") as f: json.dump(safetensors_index, f) From a87db4f48d83241de4f1aa4351c867cde45ccf97 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 16:18:19 +0000 Subject: [PATCH 173/342] mxfp4 flag --- .../openai_moe/convert_openai_weights_to_hf.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index ddf398048f04..a1cdbf82a8a8 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -130,7 +130,7 @@ def write_model( input_base_path, safe_serialization=True, instruct=False, - unpack=True, + mxfp4=False, ): os.makedirs(model_path, exist_ok=True) bos_token_id = 128000 @@ -157,7 +157,7 @@ def write_model( if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) - print("Converting ..", unpack) + print("Converting ..", mxfp4) all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) @@ -183,7 +183,7 @@ def write_model( state_dict[k_key] = k.contiguous().to(torch.bfloat16) state_dict[v_key] = v.contiguous().to(torch.bfloat16) elif re.search("gate_up_proj|down_proj", new_key) and "bias" not in new_key: - if unpack: + if not mxfp4: if "scales" in new_key: continue elif "blocks" in new_key: @@ -214,7 +214,7 @@ def write_model( del final_ gc.collect() - if unpack: + if not mxfp4: print("Loading the checkpoint in a OpenAIMoe model for unpacked format") with torch.device("meta"): model = OpenAIMoeForCausalLM(config) @@ -227,7 +227,7 @@ def write_model( del state_dict, model else: - print("Saving the checkpoint in packed format") + print("Saving the checkpoint in mxfp4 format") config.quantization_config = { "quant_method": "mxfp4", "modules_to_not_convert":[ @@ -243,7 +243,7 @@ def write_model( # Safety check: reload the converted model gc.collect() # TODO: remove when mxfp4 pr is merged - if unpack: + if not mxfp4: print("Reloading the model to check if it's saved correctly.") OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") @@ -515,9 +515,9 @@ def main(): ) parser.add_argument( - "--unpack", + "--mxfp4", action="store_true", - help="Whether to unpack the model or keep the scales as in the original format. Defaults to True if not specified.", + help="Whether to use the original model with mxfp4 quantization or default to the full precision model.", ) args = parser.parse_args() @@ -526,7 +526,7 @@ def main(): input_base_path=args.input_dir, safe_serialization=args.safe_serialization, instruct=args.instruct, - unpack=args.unpack, + mxfp4=args.mxfp4, ) write_tokenizer( From c3c01f0703170c44b14a3f58daaeb5523bc2b779 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 16:24:13 +0000 Subject: [PATCH 174/342] rm print --- .../models/openai_moe/convert_openai_weights_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index a1cdbf82a8a8..3abb58755e57 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -150,14 +150,14 @@ def write_model( } config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, **original_config) - print(config) + print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} for file in list(os.listdir(input_base_path)): if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) - print("Converting ..", mxfp4) + print("Converting ..") all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) From 863630d99f9930f0fc8528de5ceb6ccc5d4b4e2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:04:06 -0700 Subject: [PATCH 175/342] Fix PAD/EOS/BOS (#18) * fix pad/eos/bos * base model maybe one day --- .../models/openai_moe/configuration_openai_moe.py | 6 ------ .../models/openai_moe/convert_openai_weights_to_hf.py | 9 +++------ 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 03f553219605..581a623280c4 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -75,9 +75,6 @@ def __init__( initializer_range: float = 0.02, max_position_embeddings=131072, rms_norm_eps: float = 1e-5, - pad_token_id: int = 0, - bos_token_id: int = 1, - eos_token_id: int = 2, rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, attention_dropout: float = 0.0, num_experts_per_tok=4, @@ -127,9 +124,6 @@ def __init__( self.output_router_logits = output_router_logits self.use_cache = use_cache super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 3abb58755e57..b258c8e2a981 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -133,9 +133,7 @@ def write_model( mxfp4=False, ): os.makedirs(model_path, exist_ok=True) - bos_token_id = 128000 - eos_token_id = 199999 if not instruct else [199999, 200018] - pad_token_id = 128004 + eos_token_id = 199999 if not instruct else 200002 original_config = json.loads((Path(input_base_path) / "config.json").read_text()) @@ -149,7 +147,7 @@ def write_model( "original_max_position_embeddings": 4096 } - config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, **original_config) + config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, eos_token_id=eos_token_id, **original_config) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} @@ -255,9 +253,7 @@ def write_model( do_sample=True, temperature=0.6, top_p=0.9, - bos_token_id=bos_token_id, eos_token_id=eos_token_id, - pad_token_id=pad_token_id, ) generation_config.save_pretrained(model_path) @@ -396,6 +392,7 @@ def __init__( kwargs["chat_template"] = chat_template self.tokenizer = PreTrainedTokenizerFast( tokenizer_object=tokenizer, + eos_token="<|return|>" if chat_template else "<|endoftext|>", model_input_names=["input_ids", "attention_mask"], model_max_length=model_max_length, **kwargs, From eab251f7fc5e4cacc2e77363c4ab9776b4ced908 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 10:41:21 +0000 Subject: [PATCH 176/342] add some doc --- .../models/openai_moe/convert_openai_weights_to_hf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index b258c8e2a981..49343c703f46 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -511,6 +511,11 @@ def main(): help="Whether the model is an instruct model", ) + # Only specify this if you want to use the model with mxfp4 quantization + # It means the model will be unpacked, and quantized using mxfp4 during inference if all the triton requirements are satisfied (triton >= 3.4.0) + # Else we have a fallback to the full precision model (bfloat16) + # If not specified, the model will be unpacked during conversion, and will be in fp8/bfloat16 during inference + # Note: mxfp4 should bring an important speedup in inference time with blackwell gpus parser.add_argument( "--mxfp4", action="store_true", From 9280e5903f0a5f879e3ac1ec5341d26bc23b347a Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Wed, 23 Jul 2025 15:04:30 +0200 Subject: [PATCH 177/342] special tokens based on harmony. --- .../models/openai_moe/convert_openai_weights_to_hf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 49343c703f46..fed38f04ebf4 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -250,10 +250,12 @@ def write_model( if instruct: print("Saving generation config...") generation_config = GenerationConfig( + bos_token_id=199998, # <|startoftext|> do_sample=True, + eos_token_id=[200002, 199999], # <|return|>, <|endoftext|> + pad_token_id=199999, # <|endoftext|> temperature=0.6, top_p=0.9, - eos_token_id=eos_token_id, ) generation_config.save_pretrained(model_path) From b382c5e0b6e9a90008bc2db91c8b699b2ab2b5ab Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Wed, 23 Jul 2025 15:14:20 +0200 Subject: [PATCH 178/342] add in tokenizer config as well. --- .../models/openai_moe/convert_openai_weights_to_hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index fed38f04ebf4..c8daf8a43ca5 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -394,7 +394,9 @@ def __init__( kwargs["chat_template"] = chat_template self.tokenizer = PreTrainedTokenizerFast( tokenizer_object=tokenizer, + bos_token="<|startoftext|>", eos_token="<|return|>" if chat_template else "<|endoftext|>", + pad_token="<|endoftext|>", model_input_names=["input_ids", "attention_mask"], model_max_length=model_max_length, **kwargs, From f8f3e40a0d01538baceb69648dfb7a5adf8d6533 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 24 Jul 2025 10:52:26 +0200 Subject: [PATCH 179/342] prepare for rebase with main --- .../modeling_flash_attention_utils.py | 19 ++----------------- src/transformers/modeling_utils.py | 6 ++---- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 1989a0d4abc1..46a48fb11485 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -26,15 +26,10 @@ is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torch_npu_available, - is_kernels_available, logging, ) -if is_kernels_available(): - from kernels import get_kernel - - logger = logging.get_logger(__name__) flash_attn_func = None @@ -411,7 +406,7 @@ def fa_peft_integration_check( flash_241 = is_flash_attn_greater_or_equal("2.4.1") deterministic_g = None -kernel = None + def _flash_attention_forward( query_states: torch.Tensor, @@ -480,16 +475,6 @@ def _flash_attention_forward( _pad_input = pad_input_fa2 _unpad_input = unpad_input_fa2 _is_fa3 = False - elif "kernel" in attn_implementation: - repo_id = attn_implementation.replace("kernel_", "").replace("_", "/") - global kernel - if kernel is None: - kernel = get_kernel(repo_id) - _flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func") - _flash_attn_func = getattr(kernel, "flash_attn_func") - _pad_input = pad_input_fa2 - _unpad_input = unpad_input_fa2 - _is_fa3 = True if not use_top_left_mask: causal = is_causal @@ -502,7 +487,7 @@ def _flash_attention_forward( _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - flash_kwargs["s_aux"] = kwargs["s_aux"] + if _is_fa3: if dropout > 0.0: logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ce514799d067..2a94743c2cb4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2310,8 +2310,7 @@ def _autoset_attn_implementation( try: kernel = get_kernel(repo_id) ALL_ATTENTION_FUNCTIONS.register( - f"kernel_{repo_id.replace('/', '_')}", - ALL_ATTENTION_FUNCTIONS["flash_attention_2"], # we need our extra layer to support kwargs + f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name) ) config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}" except FileNotFoundError as e: @@ -4944,9 +4943,8 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: from kernels import Device, kernelize - from kernels.layer import Mode - kernelize(model, mode=Mode.TRAINING, device=Device(type=model.device.type)) + kernelize(model, device=Device(type=model.device.type)) # If it is a model with generation capabilities, attempt to load generation files (generation config, # custom generate function) From 60af84194641badf7b4c6e6dfd5d699c3dee548b Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Thu, 24 Jul 2025 11:06:33 +0200 Subject: [PATCH 180/342] Fix for initialize_tensor_parallelism now returning 4-tuple ``` [rank0]: File "/fsx/edward/work/openai-tsm-examples/examples/generate.py", line 17, in [rank0]: model = AutoModelForCausalLM.from_pretrained( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/models/auto/auto_factory.py", line 600, in from_pretrained [rank0]: return model_class.from_pretrained( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 316, in _wrapper [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 4748, in from_pretrained [rank0]: tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: ValueError: too many values to unpack (expected 3) ``` --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5afcf000c6b3..08e28dbe47e1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4745,7 +4745,7 @@ def from_pretrained( # `device_map` pointing to the correct device if tp_plan is not None: if device_mesh is None: - tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) + tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=None) else: if "tp" not in device_mesh.mesh_dim_names: raise ValueError( From 1ce172b47025ec93de9cda3bb69c00b8b22935ab Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 10 Jun 2025 14:28:20 +0000 Subject: [PATCH 181/342] mxfp4 --- src/transformers/integrations/__init__.py | 1 + src/transformers/integrations/mxfp4.py | 182 +++++++++++++++++ .../quantizers/quantizer_mxfp4.py | 187 ++++++++++++++++++ src/transformers/utils/quantization_config.py | 21 ++ 4 files changed, 391 insertions(+) create mode 100644 src/transformers/integrations/mxfp4.py create mode 100644 src/transformers/quantizers/quantizer_mxfp4.py diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 0c4d169380b5..d3f0c983eb70 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,6 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], + "mxpf4":["replace_with_mxfp4_linear"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py new file mode 100644 index 000000000000..d7711782f2d3 --- /dev/null +++ b/src/transformers/integrations/mxfp4.py @@ -0,0 +1,182 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..activations import ACT2FN +from ..utils import is_accelerate_available, is_torch_available, logging + +if is_torch_available(): + import torch + from torch import nn + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +logger = logging.get_logger(__name__) + +class Mxfp4Linear(torch.nn.Linear): + def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): + super().__init__(in_features, out_features, bias) + self.in_features = in_features + self.out_features = out_features + + # dtype torch.float4_e2m1fn not supported yet + self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype)) + + if bias: + self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype)) + else: + self.bias = None + + def forward(self, x): + """ + update + """ + return + +class Mxfp4OaiTextExperts(nn.Module): + def __init__(self, config, dtype=torch.float32): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + # + self.gate_up_proj = torch.nn.Parameter( + torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn) + ) + # self.gate_up_proj_scale = torch.nn.Parameter( + # torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32) + # ) + self.down_proj = torch.nn.Parameter( + torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn) + ) + # self.down_proj_scale = torch.nn.Parameter( + # torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32) + # ) + + def forward(self, hidden_states): + # Reshape hidden states for expert computation + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + num_tokens = None + + # Pre-allocate tensor for all expert outputs with same shape as hidden_states + next_states = torch.empty_like(hidden_states) + + for i in range(self.num_experts): + expert_hidden = hidden_states[i] + expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size) + """ + Update + + """ + # next_states[i] = ... + next_states = next_states.to(hidden_states.device) + return next_states.view(-1, self.hidden_size) + +def _replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + pre_quantized=False, + config=None, + tp_plan=None, +): + import re + + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + + if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert: + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(include_buffers=True): + in_features = module.in_features + out_features = module.out_features + model._modules[name] = Mxfp4Linear( + in_features, + out_features, + module.bias is not None, + ) + has_been_replaced = True + model._modules[name].requires_grad_(False) + if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert: + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(include_buffers=True): + tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + model._modules[name] = Mxfp4OaiTextExperts( + config.text_config, + ) + model._modules[name].input_scale_ub = torch.tensor( + [quantization_config.activation_scale_ub], dtype=torch.float + ) + + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + pre_quantized=pre_quantized, + config=config, + tp_plan=tp_plan, + ) + current_key_name.pop(-1) + return model, has_been_replaced + +def replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + pre_quantized=False, + config=None, + tp_plan=None, +): + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _replace_with_mxfp4_linear( + model, + modules_to_not_convert, + current_key_name, + quantization_config, + pre_quantized=pre_quantized, + config=config, + tp_plan=tp_plan, + ) + if not has_been_replaced: + logger.warning( + "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model \ No newline at end of file diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py new file mode 100644 index 000000000000..d30b10e87845 --- /dev/null +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -0,0 +1,187 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from .base import HfQuantizer + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging +from .quantizers_utils import get_module_from_name + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class Mxfp4HfQuantizer(HfQuantizer): + """ + FP4 quantization using fbgemm kernels + """ + + requires_parameters_quantization = True + requires_calibration = False + + required_packages = ["accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_torch_available(): + raise ImportError( + "Using mxfp4 quantization requires torch" + "Please install the latest version of torch ( pip install --upgrade torch )" + ) + if not torch.cuda.is_available(): + raise RuntimeError("Using MXFP4 quantized models requires a GPU") + + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + if major < 9: + raise ValueError( + "FP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" + ) + + device_map = kwargs.get("device_map", None) + if device_map is None: + logger.warning_once( + "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set " + "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " + ) + elif device_map is not None: + if ( + not self.pre_quantized + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device." + "This is not supported when the model is quantized on the fly. " + "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = torch.bfloat16 + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to " + "requirements of `fbgemm-gpu` to enable model loading in fp4. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.bfloat16 to remove this warning.", + torch_dtype, + ) + return torch_dtype + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + from ..integrations import Mxfp4Linear, Mxfp4OaiTextExperts + + module, tensor_name = get_module_from_name(model, param_name) + + if isinstance(module, Mxfp4Linear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.float4_e2m1fn: + raise ValueError("Expect quantized weights but got an unquantized weight") + return False + return True + if isinstance(module, Mxfp4OaiTextExperts): + if self.pre_quantized or tensor_name == "bias": + return False + else: + if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return True + return False + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + Quantizes weights into weight and weight_scale + """ + + pass + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + keep_in_fp32_modules: Optional[List[str]] = None, + **kwargs, + ): + from ..integrations import replace_with_mxfp4_linear + + tp_plan = model._tp_plan + self.modules_to_not_convert = self.get_modules_to_not_convert( + model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules + ) + + config = model.config + model = replace_with_mxfp4_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + config=config, + tp_plan=tp_plan, + ) + + model.config.quantization_config = self.quantization_config + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + from ..integrations import Mxfp4Linear, Mxfp4Llama4TextExperts + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4Llama4TextExperts): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + def update_tp_plan(self, config): + if "OAI" in config.__class__.__name__: + # config.base_model_tp_plan = text_plan + return config + return config + + def is_serializable(self, safe_serialization=None): + return True + + @property + def is_trainable(self) -> bool: + return False + \ No newline at end of file diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 0bc616c6ff9c..ac7aeeb4b465 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -65,6 +65,7 @@ class QuantizationMethod(str, Enum): QUARK = "quark" FPQUANT = "fp_quant" AUTOROUND = "auto-round" + MXFP4 = "mxfp4" class AWQLinearVersion(str, Enum): @@ -2048,3 +2049,23 @@ def __init__( self.json_export_config = JsonExporterConfig() self.quant_method = QuantizationMethod.QUARK + +@dataclass +class Mxfp4Config(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using mxfp4 quantization. + + Args: + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + """ + + def __init__( + self, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.MXFP4 + self.modules_to_not_convert = modules_to_not_convert \ No newline at end of file From c0bee22217a9cf87ea3d5f9fcd2b6a45c3c78fd2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 12 Jun 2025 09:58:33 +0000 Subject: [PATCH 182/342] mxfp4 draft --- src/transformers/integrations/__init__.py | 3 +- src/transformers/integrations/mxfp4.py | 156 ++++++++++-------- .../integrations/tensor_parallel.py | 2 + src/transformers/modeling_utils.py | 5 +- src/transformers/quantizers/auto.py | 4 + .../quantizers/quantizer_mxfp4.py | 34 ++-- src/transformers/utils/quantization_config.py | 3 +- 7 files changed, 122 insertions(+), 85 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index d3f0c983eb70..5b4e1c17bc57 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxpf4":["replace_with_mxfp4_linear"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4Linear", "Mxfp4OpenaiExperts"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,6 +256,7 @@ run_hp_search_sigopt, run_hp_search_wandb, ) + from .mxfp4 import replace_with_mxfp4_linear, Mxfp4Linear, Mxfp4OpenaiExperts from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index d7711782f2d3..cd83130e12a6 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..activations import ACT2FN from ..utils import is_accelerate_available, is_torch_available, logging + if is_torch_available(): import torch from torch import nn @@ -22,21 +22,21 @@ if is_accelerate_available(): from accelerate import init_empty_weights +import re + logger = logging.get_logger(__name__) + class Mxfp4Linear(torch.nn.Linear): def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): super().__init__(in_features, out_features, bias) - self.in_features = in_features - self.out_features = out_features - # dtype torch.float4_e2m1fn not supported yet - self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn)) - self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype)) + self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e5m2)) + # self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype)) if bias: - self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype)) + self.bias = torch.nn.Parameter(torch.zeros((out_features), dtype=weight_dtype)) else: self.bias = None @@ -46,47 +46,78 @@ def forward(self, x): """ return -class Mxfp4OaiTextExperts(nn.Module): - def __init__(self, config, dtype=torch.float32): + +# maybe subclass +class Mxfp4OpenaiExperts(nn.Module): + def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.act_fn = ACT2FN[config.hidden_act] - - # - self.gate_up_proj = torch.nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn) + + # dtype torch.float4_e2m1fn not supported yet + self.gate_up_proj = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim, dtype=torch.float8_e5m2), + ) + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter( + torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e5m2), ) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.alpha = 1.702 + # self.gate_up_proj_scale = torch.nn.Parameter( - # torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32) + # torch.zeros((self.num_experts, 1, self.expert_dim * 2)) # ) - self.down_proj = torch.nn.Parameter( - torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn) - ) # self.down_proj_scale = torch.nn.Parameter( - # torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32) + # torch.zeros((self.num_experts, self.hidden_size, 1)) # ) - def forward(self, hidden_states): - # Reshape hidden states for expert computation - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - num_tokens = None - - # Pre-allocate tensor for all expert outputs with same shape as hidden_states - next_states = torch.empty_like(hidden_states) - - for i in range(self.num_experts): - expert_hidden = hidden_states[i] - expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size) - """ - Update - - """ - # next_states[i] = ... - next_states = next_states.to(hidden_states.device) - return next_states.view(-1, self.hidden_size) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + """ + To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 + """ + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + with torch.no_grad(): + idx, top_x = torch.where( + expert_mask[expert_idx][0] + ) # idx: top-1/top-2 indicator, top_x: token indices + current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + gate_up = ( + current_state @ self.gate_up_proj[expert_idx].to(torch.bfloat16) + self.gate_up_proj_bias[expert_idx] + ) # (num_tokens, 2 * interm_dim) + gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) + gated_output = (up + 1) * glu # (num_tokens, interm_dim) + out = ( + gated_output @ self.down_proj[expert_idx].to(torch.bfloat16) + self.down_proj_bias[expert_idx] + ) # (num_tokens, hidden_dim) + weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + else: + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(torch.bfloat16)) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(torch.bfloat16)) + self.down_proj_bias[..., None, :] + next_states = next_states.view(-1, self.hidden_size) + return next_states + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False + def _replace_with_mxfp4_linear( model, @@ -98,42 +129,30 @@ def _replace_with_mxfp4_linear( config=None, tp_plan=None, ): - import re - if current_key_name is None: current_key_name = [] for name, module in model.named_children(): current_key_name.append(name) - if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert: - current_key_name_str = ".".join(current_key_name) - if not any( - (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert - ): - with init_empty_weights(include_buffers=True): - in_features = module.in_features - out_features = module.out_features - model._modules[name] = Mxfp4Linear( - in_features, - out_features, - module.bias is not None, - ) - has_been_replaced = True - model._modules[name].requires_grad_(False) - if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert: - current_key_name_str = ".".join(current_key_name) - if not any( - (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert - ): - with init_empty_weights(include_buffers=True): - tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None - model._modules[name] = Mxfp4OaiTextExperts( - config.text_config, - ) - model._modules[name].input_scale_ub = torch.tensor( - [quantization_config.activation_scale_ub], dtype=torch.float + if (isinstance(module, nn.Linear)) and should_convert_module(current_key_name, modules_to_not_convert): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + model._modules[name] = Mxfp4Linear( + in_features, + out_features, + module.bias is not None, ) + has_been_replaced = True + model._modules[name].requires_grad_(False) + if module.__class__.__name__ == "OpenaiExperts" and should_convert_module( + current_key_name, modules_to_not_convert + ): + with init_empty_weights(): + # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + model._modules[name] = Mxfp4OpenaiExperts(config) + has_been_replaced=True if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_mxfp4_linear( @@ -149,6 +168,7 @@ def _replace_with_mxfp4_linear( current_key_name.pop(-1) return model, has_been_replaced + def replace_with_mxfp4_linear( model, modules_to_not_convert=None, @@ -179,4 +199,4 @@ def replace_with_mxfp4_linear( " a bug." ) - return model \ No newline at end of file + return model diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d24bc9fb0ca3..456a6dbb952c 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -148,6 +148,8 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> str "F64": torch.float64, "I64": torch.int64, "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + } diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 08e28dbe47e1..9bc2b894d1fc 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -715,11 +715,13 @@ def _infer_parameter_dtype( else: raise e is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + is_torch_e5m2_available = hasattr(torch, "float8_e5m2") # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # in int/uint/bool and not cast them. casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn and not is_param_float8_e5m2: # First fp32 if part of the exception list if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name): casting_dtype = torch.float32 @@ -789,7 +791,6 @@ def _load_state_dict_into_meta_model( param = file_pointer.get_slice(serialized_param_name) else: param = empty_param.to(tensor_device) # It is actually not empty! - to_contiguous, casting_dtype = _infer_parameter_dtype( model, param_name, diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index e4fbaadb5d76..9d91b8edb696 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -31,6 +31,7 @@ GPTQConfig, HiggsConfig, HqqConfig, + Mxfp4Config, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, @@ -54,6 +55,7 @@ from .quantizer_gptq import GptqHfQuantizer from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer +from .quantizer_mxfp4 import Mxfp4HfQuantizer from .quantizer_quanto import QuantoHfQuantizer from .quantizer_quark import QuarkHfQuantizer from .quantizer_spqr import SpQRHfQuantizer @@ -81,6 +83,7 @@ "spqr": SpQRHfQuantizer, "fp8": FineGrainedFP8HfQuantizer, "auto-round": AutoRoundQuantizer, + "mxfp4": Mxfp4HfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -103,6 +106,7 @@ "spqr": SpQRConfig, "fp8": FineGrainedFP8Config, "auto-round": AutoRoundConfig, + "mxfp4": Mxfp4Config, } logger = logging.get_logger(__name__) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index d30b10e87845..a94cc8a381ba 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -15,12 +15,14 @@ from .base import HfQuantizer + if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging +from ..utils import is_torch_available, logging, is_accelerate_available from .quantizers_utils import get_module_from_name + if is_torch_available(): import torch @@ -33,7 +35,8 @@ class Mxfp4HfQuantizer(HfQuantizer): """ requires_parameters_quantization = True - requires_calibration = False + # to remove if we decide to allow quantizing weights with this method + requires_calibration = True required_packages = ["accelerate"] @@ -56,6 +59,11 @@ def validate_environment(self, *args, **kwargs): raise ValueError( "FP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" ) + # TODO: update accelerate version when it is released + if not is_accelerate_available(): + raise ImportError( + f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=1.8.0'`" + ) device_map = kwargs.get("device_map", None) if device_map is None: @@ -95,18 +103,20 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ): - from ..integrations import Mxfp4Linear, Mxfp4OaiTextExperts + from ..integrations import Mxfp4Linear, Mxfp4OpenaiExperts module, tensor_name = get_module_from_name(model, param_name) if isinstance(module, Mxfp4Linear): if self.pre_quantized or tensor_name == "bias": - if tensor_name == "weight" and param_value.dtype != torch.float4_e2m1fn: + if tensor_name == "weight" and param_value.dtype != torch.float8_e5m2: raise ValueError("Expect quantized weights but got an unquantized weight") return False return True - if isinstance(module, Mxfp4OaiTextExperts): + if isinstance(module, Mxfp4OpenaiExperts): if self.pre_quantized or tensor_name == "bias": + if (tensor_name == "down_proj" or tensor_name == "gate_up_proj") and param_value.dtype != torch.float8_e5m2: + raise ValueError("Expect quantized weights but got an unquantized weight") return False else: if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": @@ -126,8 +136,7 @@ def create_quantized_param( """ Quantizes weights into weight and weight_scale """ - - pass + pass def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): return model @@ -158,11 +167,11 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - from ..integrations import Mxfp4Linear, Mxfp4Llama4TextExperts + from ..integrations import Mxfp4Linear, Mxfp4OpenaiExperts not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4Llama4TextExperts): + if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4OpenaiExperts): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") @@ -173,9 +182,9 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): - if "OAI" in config.__class__.__name__: - # config.base_model_tp_plan = text_plan - return config + # TODO: for tp support + # if "OpenaiExperts" in config.__class__.__name__: + # return config return config def is_serializable(self, safe_serialization=None): @@ -184,4 +193,3 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: return False - \ No newline at end of file diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index ac7aeeb4b465..e29ef2a458e4 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2050,6 +2050,7 @@ def __init__( self.quant_method = QuantizationMethod.QUARK + @dataclass class Mxfp4Config(QuantizationConfigMixin): """ @@ -2068,4 +2069,4 @@ def __init__( **kwargs, ): self.quant_method = QuantizationMethod.MXFP4 - self.modules_to_not_convert = modules_to_not_convert \ No newline at end of file + self.modules_to_not_convert = modules_to_not_convert From fe896d36305eb358778e8b00fee8ded16c6a38dc Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 12 Jun 2025 10:57:54 +0000 Subject: [PATCH 183/342] fix --- src/transformers/integrations/mxfp4.py | 6 +++--- src/transformers/quantizers/quantizer_mxfp4.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index cd83130e12a6..953cbe952d5c 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -48,7 +48,7 @@ def forward(self, x): # maybe subclass -class Mxfp4OpenaiExperts(nn.Module): +class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts @@ -146,12 +146,12 @@ def _replace_with_mxfp4_linear( ) has_been_replaced = True model._modules[name].requires_grad_(False) - if module.__class__.__name__ == "OpenaiExperts" and should_convert_module( + if module.__class__.__name__ == "OpenAIMoeExperts" and should_convert_module( current_key_name, modules_to_not_convert ): with init_empty_weights(): # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None - model._modules[name] = Mxfp4OpenaiExperts(config) + model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True if len(list(module.children())) > 0: diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index a94cc8a381ba..df0defeeb03d 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -103,7 +103,7 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ): - from ..integrations import Mxfp4Linear, Mxfp4OpenaiExperts + from ..integrations import Mxfp4Linear, Mxfp4OpenAIMoeExperts module, tensor_name = get_module_from_name(model, param_name) @@ -113,7 +113,7 @@ def check_quantized_param( raise ValueError("Expect quantized weights but got an unquantized weight") return False return True - if isinstance(module, Mxfp4OpenaiExperts): + if isinstance(module, Mxfp4OpenAIMoeExperts): if self.pre_quantized or tensor_name == "bias": if (tensor_name == "down_proj" or tensor_name == "gate_up_proj") and param_value.dtype != torch.float8_e5m2: raise ValueError("Expect quantized weights but got an unquantized weight") @@ -167,11 +167,11 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - from ..integrations import Mxfp4Linear, Mxfp4OpenaiExperts + from ..integrations import Mxfp4Linear, Mxfp4OpenAIMoeExperts not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4OpenaiExperts): + if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4OpenAIMoeExperts): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") @@ -183,7 +183,7 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li def update_tp_plan(self, config): # TODO: for tp support - # if "OpenaiExperts" in config.__class__.__name__: + # if "OpenAIMoeExperts" in config.__class__.__name__: # return config return config From 174147dfd9ebf5c42b50440193afe22fd3f5640d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 23 Jun 2025 16:07:51 +0000 Subject: [PATCH 184/342] fix import --- src/transformers/integrations/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 5b4e1c17bc57..a724f6e69115 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4Linear", "Mxfp4OpenaiExperts"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4Linear", "Mxfp4OpenAIMoeExperts"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,7 +256,7 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - from .mxfp4 import replace_with_mxfp4_linear, Mxfp4Linear, Mxfp4OpenaiExperts + from .mxfp4 import replace_with_mxfp4_linear, Mxfp4Linear, Mxfp4OpenAIMoeExperts from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear From b8215ddd172f4cc6a0062b1adff7bf0ae9525124 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 23 Jun 2025 16:12:55 +0000 Subject: [PATCH 185/342] draft --- src/transformers/integrations/mxfp4.py | 123 ++++++++++++++++++------- 1 file changed, 90 insertions(+), 33 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 953cbe952d5c..82dc698a933d 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -27,6 +27,9 @@ logger = logging.get_logger(__name__) +from triton_kernels.matmul_ogs import matmul_ogs +from triton_kernels.routing import (GatherIndx, RoutingData, ScatterIndx, + routing) class Mxfp4Linear(torch.nn.Linear): def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): @@ -46,7 +49,6 @@ def forward(self, x): """ return - # maybe subclass class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): @@ -74,41 +76,96 @@ def __init__(self, config): # torch.zeros((self.num_experts, self.hidden_size, 1)) # ) - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, router_logits=None, topk=None, router_indices=None, routing_weights=None) -> torch.Tensor: """ To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 """ - if self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) - expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted: - with torch.no_grad(): - idx, top_x = torch.where( - expert_mask[expert_idx][0] - ) # idx: top-1/top-2 indicator, top_x: token indices - current_state = hidden_states[top_x] # (num_tokens, hidden_dim) - gate_up = ( - current_state @ self.gate_up_proj[expert_idx].to(torch.bfloat16) + self.gate_up_proj_bias[expert_idx] - ) # (num_tokens, 2 * interm_dim) - gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) - glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) - gated_output = (up + 1) * glu # (num_tokens, interm_dim) - out = ( - gated_output @ self.down_proj[expert_idx].to(torch.bfloat16) + self.down_proj_bias[expert_idx] - ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) - next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) - else: - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(torch.bfloat16)) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(torch.bfloat16)) + self.down_proj_bias[..., None, :] - next_states = next_states.view(-1, self.hidden_size) - return next_states + # if self.training: + # next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted: + # with torch.no_grad(): + # idx, top_x = torch.where( + # expert_mask[expert_idx][0] + # ) # idx: top-1/top-2 indicator, top_x: token indices + # current_state = hidden_states[top_x] # (num_tokens, hidden_dim) + # gate_up = ( + # current_state @ self.gate_up_proj[expert_idx].to(torch.bfloat16) + self.gate_up_proj_bias[expert_idx] + # ) # (num_tokens, 2 * interm_dim) + # gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) + # glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) + # gated_output = (up + 1) * glu # (num_tokens, interm_dim) + # out = ( + # gated_output @ self.down_proj[expert_idx].to(torch.bfloat16) + self.down_proj_bias[expert_idx] + # ) # (num_tokens, hidden_dim) + # weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) + # next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + # else: + # hidden_states = hidden_states.repeat(self.num_experts, 1) + # hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + # gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(torch.bfloat16)) + self.gate_up_proj_bias[..., None, :] + # gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + # glu = gate * torch.sigmoid(gate * self.alpha) + # next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(torch.bfloat16)) + self.down_proj_bias[..., None, :] + # next_states = next_states.view(-1, self.hidden_size) + + use_oai_kernels = True + if use_oai_kernels: + renormalize = True + print(router_logits) + print(topk) + routing_data, gather_idx, scatter_idx = routing(router_logits, topk, renormalize) + print(routing_data) + print(gather_idx) + print(scatter_idx) + + # if global_num_experts == -1: + # global_num_experts = E + + # # consistent with default implementation + # intermediate_cache2 = torch.empty((M * n_expts_act, N // 2), + # device="cuda", + # dtype=dtype) + + # intermediate_cache1 = matmul_ogs(hidden_states, + # w1, + # None, + # routing_data, + # gather_indx=gather_indx, + # gammas=routing_data.gate_scal + # if apply_router_weight_on_input else None) + + # if activation == "silu": + # torch.ops._C.silu_and_mul(intermediate_cache2, + # intermediate_cache1.view(-1, N)) + # elif activation == "gelu": + # torch.ops._C.gelu_and_mul(intermediate_cache2, + # intermediate_cache1.view(-1, N)) + # else: + # raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + # intermediate_cache3 = matmul_ogs( + # intermediate_cache2, + # w2, + # None, + # routing_data, + # scatter_indx=scatter_indx, + # gammas=None + # if apply_router_weight_on_input else routing_data.gate_scal) + + + + # if n_expts_tot > 1: + # logits = matmul_ogs(xg, wg, bg, precision_config=pcg) + # rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP) + # else: + # rdata, gather_indx, scatter_indx = None, None, None + # x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) + # x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2) + + # return intermediate_cache3 def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) From 62f77e17e2cfa1d0855586843fcc7926601de40b Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 4 Jul 2025 14:24:26 +0000 Subject: [PATCH 186/342] draft impl --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/mxfp4.py | 233 ++++++++---------- src/transformers/modeling_utils.py | 3 +- .../quantizers/quantizer_mxfp4.py | 98 ++++++-- 4 files changed, 186 insertions(+), 152 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index a724f6e69115..4663e811a521 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4Linear", "Mxfp4OpenAIMoeExperts"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,7 +256,7 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - from .mxfp4 import replace_with_mxfp4_linear, Mxfp4Linear, Mxfp4OpenAIMoeExperts + from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, shuffle_weight, Mxfp4OpenAIMoeExperts from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 82dc698a933d..1aa570f15c96 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -27,27 +27,44 @@ logger = logging.get_logger(__name__) -from triton_kernels.matmul_ogs import matmul_ogs -from triton_kernels.routing import (GatherIndx, RoutingData, ScatterIndx, - routing) -class Mxfp4Linear(torch.nn.Linear): - def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): - super().__init__(in_features, out_features, bias) - # dtype torch.float4_e2m1fn not supported yet - self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e5m2)) - # self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype)) - - if bias: - self.bias = torch.nn.Parameter(torch.zeros((out_features), dtype=weight_dtype)) - else: - self.bias = None - - def forward(self, x): - """ - update - """ - return +def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp + from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx + swizzle_axis = 2 if swizzle_mx_scale else None + w = w.to(torch.bfloat16) + w, mx_scales, weight_scale_shape = downcast_to_mxfp( + w, + torch.uint8, + axis=1, + swizzle_axis=swizzle_axis, + swizzle_scale=swizzle_mx_scale, + swizzle_value=swizzle_mx_value) + return w, InFlexData(), MicroscalingCtx( + weight_scale=mx_scales, + swizzle_scale=swizzle_mx_scale, + swizzle_value=swizzle_mx_value, + actual_weight_scale_shape=weight_scale_shape) + +def shuffle_weight(w: "torch.Tensor") -> "torch.Tensor": + # Shuffle weight along the last dimension so that + # we folded the weights to adjance location + # Example: + # input: + # [[1, 2, 3, 4, 5, 6], + # [7, 8, 9, 10, 11, 12]] + # output: + # [[1, 4, 2, 5, 3, 6], + # [7, 10, 8, 11, 9, 12]] + # This will be used together with triton swiglu kernel + shape = w.shape + N = shape[-1] + first = w[..., :N // 2] + second = w[..., N // 2:] + + stacked = torch.stack((first, second), dim=-1) + w_shuffled = stacked.reshape(shape) + return w_shuffled # maybe subclass class Mxfp4OpenAIMoeExperts(nn.Module): @@ -68,104 +85,75 @@ def __init__(self, config): ) self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 + + self.gate_up_proj_precision_config = None + self.down_proj_precision_config = None - # self.gate_up_proj_scale = torch.nn.Parameter( - # torch.zeros((self.num_experts, 1, self.expert_dim * 2)) - # ) - # self.down_proj_scale = torch.nn.Parameter( - # torch.zeros((self.num_experts, self.hidden_size, 1)) - # ) + smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x + self.gate_up_proj_right_pad = smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + self.gate_up_proj_bottom_pad = 0 + + self.down_proj_right_pad = smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size + self.down_proj_bottom_pad = self.gate_up_proj_right_pad // 2 + def forward(self, hidden_states: torch.Tensor, router_logits=None, topk=None, router_indices=None, routing_weights=None) -> torch.Tensor: """ To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 """ - # if self.training: - # next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hitted: - # with torch.no_grad(): - # idx, top_x = torch.where( - # expert_mask[expert_idx][0] - # ) # idx: top-1/top-2 indicator, top_x: token indices - # current_state = hidden_states[top_x] # (num_tokens, hidden_dim) - # gate_up = ( - # current_state @ self.gate_up_proj[expert_idx].to(torch.bfloat16) + self.gate_up_proj_bias[expert_idx] - # ) # (num_tokens, 2 * interm_dim) - # gate, up = gate_up.chunk(2, dim=-1) # (num_tokens, interm_dim) - # glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) - # gated_output = (up + 1) * glu # (num_tokens, interm_dim) - # out = ( - # gated_output @ self.down_proj[expert_idx].to(torch.bfloat16) + self.down_proj_bias[expert_idx] - # ) # (num_tokens, hidden_dim) - # weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) - # next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) - # else: - # hidden_states = hidden_states.repeat(self.num_experts, 1) - # hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - # gate_up = torch.bmm(hidden_states, self.gate_up_proj.to(torch.bfloat16)) + self.gate_up_proj_bias[..., None, :] - # gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - # glu = gate * torch.sigmoid(gate * self.alpha) - # next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(torch.bfloat16)) + self.down_proj_bias[..., None, :] - # next_states = next_states.view(-1, self.hidden_size) - use_oai_kernels = True - if use_oai_kernels: + # type check, uint8 means mxfp4 + #TODO: fp8 x mxfp4 on blackwell + assert hidden_states.dtype == torch.bfloat16 + assert self.gate_up_proj.dtype in (torch.bfloat16, torch.uint8) + assert self.down_proj.dtype in (torch.bfloat16, torch.uint8) + assert self.gate_up_proj_bias.dtype == torch.float32 + assert self.down_proj_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + if self.gate_up_proj.dtype != torch.uint8: + assert hidden_states.ndim == 2 + assert hidden_states.shape[-1] == self.gate_up_proj.shape[-2] + assert self.down_proj.shape[-1] == self.gate_up_proj.shape[1] + + from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs + from triton_kernels.swiglu import swiglu_fn + from triton_kernels.routing import routing + # TODO: needed in the context of device_map, maybe not for TP + with torch.cuda.device(hidden_states.device): renormalize = True - print(router_logits) - print(topk) - routing_data, gather_idx, scatter_idx = routing(router_logits, topk, renormalize) - print(routing_data) - print(gather_idx) - print(scatter_idx) - - # if global_num_experts == -1: - # global_num_experts = E - - # # consistent with default implementation - # intermediate_cache2 = torch.empty((M * n_expts_act, N // 2), - # device="cuda", - # dtype=dtype) - - # intermediate_cache1 = matmul_ogs(hidden_states, - # w1, - # None, - # routing_data, - # gather_indx=gather_indx, - # gammas=routing_data.gate_scal - # if apply_router_weight_on_input else None) - - # if activation == "silu": - # torch.ops._C.silu_and_mul(intermediate_cache2, - # intermediate_cache1.view(-1, N)) - # elif activation == "gelu": - # torch.ops._C.gelu_and_mul(intermediate_cache2, - # intermediate_cache1.view(-1, N)) - # else: - # raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - # intermediate_cache3 = matmul_ogs( - # intermediate_cache2, - # w2, - # None, - # routing_data, - # scatter_indx=scatter_indx, - # gammas=None - # if apply_router_weight_on_input else routing_data.gate_scal) - - - - # if n_expts_tot > 1: - # logits = matmul_ogs(xg, wg, bg, precision_config=pcg) - # rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP) - # else: - # rdata, gather_indx, scatter_indx = None, None, None - # x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) - # x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2) - - # return intermediate_cache3 + routing_data, gather_idx, scatter_idx = routing(router_logits, topk, sm_first=not renormalize) + act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) + + apply_router_weight_on_input = False + intermediate_cache1 = matmul_ogs(hidden_states, + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=routing_data.gate_scal if apply_router_weight_on_input else None, + fused_activation=act) + + intermediate_cache3 = matmul_ogs( + intermediate_cache1, + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=None if apply_router_weight_on_input else routing_data.gate_scal) + + # manually crop the tensor since oai kernel pad the output + output_states = intermediate_cache3[..., :self.hidden_size].contiguous() + torch.cuda.synchronize() + return output_states + +def mlp_forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.router(hidden_states) + routed_out = self.experts(hidden_states, router_logits=router_logits, topk=self.top_k) + return routed_out, router_logits def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) @@ -191,26 +179,19 @@ def _replace_with_mxfp4_linear( for name, module in model.named_children(): current_key_name.append(name) - - if (isinstance(module, nn.Linear)) and should_convert_module(current_key_name, modules_to_not_convert): - with init_empty_weights(): - in_features = module.in_features - out_features = module.out_features - model._modules[name] = Mxfp4Linear( - in_features, - out_features, - module.bias is not None, - ) - has_been_replaced = True - model._modules[name].requires_grad_(False) - if module.__class__.__name__ == "OpenAIMoeExperts" and should_convert_module( - current_key_name, modules_to_not_convert - ): + if not should_convert_module(current_key_name, modules_to_not_convert): + current_key_name.pop(-1) + continue + if isinstance(module, nn.Linear): + raise NotImplementedError("Mxfp4 linear layer is not implemented yet") + if module.__class__.__name__ == "OpenAIMoeExperts": with init_empty_weights(): # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True - + if module.__class__.__name__ == "OpenAIMoeMLP": + from types import MethodType + module.forward = MethodType(mlp_forward, module) if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_mxfp4_linear( module, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9bc2b894d1fc..7911c3ccdb4a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -720,8 +720,9 @@ def _infer_parameter_dtype( # in int/uint/bool and not cast them. casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 + # is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn and not is_param_float8_e5m2: + # if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # First fp32 if part of the exception list if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name): casting_dtype = torch.float32 diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index df0defeeb03d..95360bcc25a7 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -55,11 +55,15 @@ def validate_environment(self, *args, **kwargs): compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability + # TODO: Fix that + # if not is_triton_kernels_availalble(): + # raise ValueError( + # "MXFP4 quantization requires triton_kernels library" + # ) if major < 9: raise ValueError( - "FP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" + "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" ) - # TODO: update accelerate version when it is released if not is_accelerate_available(): raise ImportError( f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=1.8.0'`" @@ -82,6 +86,21 @@ def validate_environment(self, *args, **kwargs): "This is not supported when the model is quantized on the fly. " "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." ) + from triton_kernels.numerics_details.mxfp import SwizzlingType + + if major < 9: + # NYI for Ampere + swizzle_mx_value = None + swizzle_mx_scale = None + elif major < 10: + swizzle_mx_value = SwizzlingType.HOPPER + swizzle_mx_scale = SwizzlingType.HOPPER + else: + swizzle_mx_value = None + swizzle_mx_scale = SwizzlingType.BLACKWELL + + self.swizzle_mx_value = swizzle_mx_value + self.swizzle_mx_scale = swizzle_mx_scale def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: @@ -103,25 +122,13 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ): - from ..integrations import Mxfp4Linear, Mxfp4OpenAIMoeExperts - + from ..integrations import Mxfp4OpenAIMoeExperts module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, Mxfp4Linear): - if self.pre_quantized or tensor_name == "bias": - if tensor_name == "weight" and param_value.dtype != torch.float8_e5m2: - raise ValueError("Expect quantized weights but got an unquantized weight") - return False - return True if isinstance(module, Mxfp4OpenAIMoeExperts): - if self.pre_quantized or tensor_name == "bias": - if (tensor_name == "down_proj" or tensor_name == "gate_up_proj") and param_value.dtype != torch.float8_e5m2: - raise ValueError("Expect quantized weights but got an unquantized weight") + if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]: return False - else: - if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") - return True + return True return False def create_quantized_param( @@ -133,12 +140,57 @@ def create_quantized_param( state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, ): - """ - Quantizes weights into weight and weight_scale - """ - pass + from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + module, _ = get_module_from_name(model, param_name) + + # calculating padding needed for each tensor + if "gate_up_proj" in param_name and isinstance(module, Mxfp4OpenAIMoeExperts): + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + elif "down_proj" in param_name and isinstance(module, Mxfp4OpenAIMoeExperts): + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + + with torch.cuda.device(target_device): + loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) + loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + # delete intermediate tensor immediate to prevent OOM + del loaded_weight_shuffled + torch.cuda.empty_cache() + loaded_weight, flex, mx = quantize_to_mxfp4( + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + if isinstance(module, Mxfp4OpenAIMoeExperts): + if "gate_up_proj" in param_name: + module.gate_up_proj_precision_config = PrecisionConfig( + mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + elif "down_proj" in param_name: + module.down_proj_precision_config = PrecisionConfig( + mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts + + for module in model.modules(): + if isinstance(module, Mxfp4OpenAIMoeExperts): + gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) + gate_up_proj_bias = gate_up_proj_bias.to(torch.float32) + gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), + mode="constant", + value=0) + + down_proj_bias = module.down_proj_bias.to(torch.float32) + down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), + mode="constant", + value=0) + module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) return model def _process_model_before_weight_loading( @@ -167,11 +219,11 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - from ..integrations import Mxfp4Linear, Mxfp4OpenAIMoeExperts + from ..integrations import Mxfp4OpenAIMoeExperts not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, Mxfp4Linear) or isinstance(module, Mxfp4OpenAIMoeExperts): + if isinstance(module, Mxfp4OpenAIMoeExperts): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") From 6e9d0c72bb624606b4adb6d1e40c78ef80494ef2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 4 Jul 2025 15:36:55 +0000 Subject: [PATCH 187/342] finally working ! --- src/transformers/modeling_utils.py | 2 +- .../quantizers/quantizer_mxfp4.py | 57 ++++++++++--------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7911c3ccdb4a..28efaea59318 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -720,7 +720,7 @@ def _infer_parameter_dtype( # in int/uint/bool and not cast them. casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - # is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 + is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn and not is_param_float8_e5m2: # if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # First fp32 if part of the exception list diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 95360bcc25a7..c74b75ea7b55 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -145,34 +145,36 @@ def create_quantized_param( module, _ = get_module_from_name(model, param_name) - # calculating padding needed for each tensor - if "gate_up_proj" in param_name and isinstance(module, Mxfp4OpenAIMoeExperts): - right_pad = module.gate_up_proj_right_pad - bottom_pad = module.gate_up_proj_bottom_pad - elif "down_proj" in param_name and isinstance(module, Mxfp4OpenAIMoeExperts): - right_pad = module.down_proj_right_pad - bottom_pad = module.down_proj_bottom_pad - + with torch.cuda.device(target_device): - loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) - loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - # delete intermediate tensor immediate to prevent OOM - del loaded_weight_shuffled - torch.cuda.empty_cache() - loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - if isinstance(module, Mxfp4OpenAIMoeExperts): - if "gate_up_proj" in param_name: - module.gate_up_proj_precision_config = PrecisionConfig( - mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - elif "down_proj" in param_name: - module.down_proj_precision_config = PrecisionConfig( - mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + if isinstance(module, Mxfp4OpenAIMoeExperts): + if "gate_up_proj" in param_name: + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + # we only shuffle gate_proj + loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) + loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del loaded_weight_shuffled + torch.cuda.empty_cache() + loaded_weight, flex, mx = quantize_to_mxfp4( + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + elif "down_proj" in param_name: + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(param_value, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0).to(target_device) + # delete intermediate tensor immediate to prevent OOM + loaded_weight, flex, mx = quantize_to_mxfp4( + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts @@ -184,7 +186,6 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), mode="constant", value=0) - down_proj_bias = module.down_proj_bias.to(torch.float32) down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), mode="constant", From 6b8b279f9137e665df867e6023fa5f8972504cb9 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 8 Jul 2025 09:48:43 +0000 Subject: [PATCH 188/342] simplify --- src/transformers/integrations/mxfp4.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 1aa570f15c96..e4e5ad84fe60 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -170,7 +170,6 @@ def _replace_with_mxfp4_linear( current_key_name=None, quantization_config=None, has_been_replaced=False, - pre_quantized=False, config=None, tp_plan=None, ): @@ -199,7 +198,6 @@ def _replace_with_mxfp4_linear( current_key_name, quantization_config, has_been_replaced=has_been_replaced, - pre_quantized=pre_quantized, config=config, tp_plan=tp_plan, ) @@ -212,7 +210,6 @@ def replace_with_mxfp4_linear( modules_to_not_convert=None, current_key_name=None, quantization_config=None, - pre_quantized=False, config=None, tp_plan=None, ): @@ -226,7 +223,6 @@ def replace_with_mxfp4_linear( modules_to_not_convert, current_key_name, quantization_config, - pre_quantized=pre_quantized, config=config, tp_plan=tp_plan, ) From ea5c364aeb96776e75ad2fa1f7b63e5602f5d7ba Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 8 Jul 2025 13:00:12 +0000 Subject: [PATCH 189/342] add import --- src/transformers/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9186277fdd25..d405cf56f903 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -273,6 +273,7 @@ "GPTQConfig", "HiggsConfig", "HqqConfig", + "Mxfp4Config", "QuantoConfig", "QuarkConfig", "FPQuantConfig", @@ -966,6 +967,7 @@ GPTQConfig, HiggsConfig, HqqConfig, + Mxfp4Config, QuantoConfig, QuarkConfig, SpQRConfig, From 1175ab46f2a00d3472af1cab8710fe88ce64431a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 8 Jul 2025 13:30:05 +0000 Subject: [PATCH 190/342] working version --- src/transformers/integrations/mxfp4.py | 13 ++- .../quantizers/quantizer_mxfp4.py | 97 +++++++++---------- src/transformers/utils/quantization_config.py | 2 +- 3 files changed, 53 insertions(+), 59 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index e4e5ad84fe60..f1c16d6552a6 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -97,7 +97,7 @@ def __init__(self, config): self.down_proj_right_pad = smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size self.down_proj_bottom_pad = self.gate_up_proj_right_pad // 2 - def forward(self, hidden_states: torch.Tensor, router_logits=None, topk=None, router_indices=None, routing_weights=None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: """ To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 """ @@ -118,11 +118,8 @@ def forward(self, hidden_states: torch.Tensor, router_logits=None, topk=None, ro from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn - from triton_kernels.routing import routing # TODO: needed in the context of device_map, maybe not for TP with torch.cuda.device(hidden_states.device): - renormalize = True - routing_data, gather_idx, scatter_idx = routing(router_logits, topk, sm_first=not renormalize) act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) apply_router_weight_on_input = False @@ -150,9 +147,11 @@ def forward(self, hidden_states: torch.Tensor, router_logits=None, topk=None, ro return output_states def mlp_forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_logits=router_logits, topk=self.top_k) + from triton_kernels.routing import routing + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits def should_convert_module(current_key_name, patterns): diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index c74b75ea7b55..005700c42734 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -36,7 +36,7 @@ class Mxfp4HfQuantizer(HfQuantizer): requires_parameters_quantization = True # to remove if we decide to allow quantizing weights with this method - requires_calibration = True + requires_calibration = False required_packages = ["accelerate"] @@ -142,56 +142,55 @@ def create_quantized_param( ): from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - - module, _ = get_module_from_name(model, param_name) - - with torch.cuda.device(target_device): - if isinstance(module, Mxfp4OpenAIMoeExperts): - if "gate_up_proj" in param_name: - right_pad = module.gate_up_proj_right_pad - bottom_pad = module.gate_up_proj_bottom_pad - # we only shuffle gate_proj - loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) - loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del loaded_weight_shuffled - torch.cuda.empty_cache() - loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - elif "down_proj" in param_name: - right_pad = module.down_proj_right_pad - bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(param_value, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0).to(target_device) - # delete intermediate tensor immediate to prevent OOM - loaded_weight, flex, mx = quantize_to_mxfp4( + if not self.pre_quantized: + module, _ = get_module_from_name(model, param_name) + with torch.cuda.device(target_device): + if isinstance(module, Mxfp4OpenAIMoeExperts): + if "gate_up_proj" in param_name: + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + # we only shuffle gate_proj + loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) + loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del loaded_weight_shuffled + torch.cuda.empty_cache() + loaded_weight, flex, mx = quantize_to_mxfp4( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + elif "down_proj" in param_name: + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(param_value, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0).to(target_device) + # delete intermediate tensor immediate to prevent OOM + loaded_weight, flex, mx = quantize_to_mxfp4( + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts - - for module in model.modules(): - if isinstance(module, Mxfp4OpenAIMoeExperts): - gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) - gate_up_proj_bias = gate_up_proj_bias.to(torch.float32) - gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), - mode="constant", - value=0) - down_proj_bias = module.down_proj_bias.to(torch.float32) - down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), - mode="constant", - value=0) - module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + if not self.pre_quantized: + for module in model.modules(): + if isinstance(module, Mxfp4OpenAIMoeExperts): + gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) + gate_up_proj_bias = gate_up_proj_bias.to(torch.float32) + gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), + mode="constant", + value=0) + down_proj_bias = module.down_proj_bias.to(torch.float32) + down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), + mode="constant", + value=0) + module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) return model def _process_model_before_weight_loading( @@ -212,7 +211,6 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, - pre_quantized=self.pre_quantized, config=config, tp_plan=tp_plan, ) @@ -235,13 +233,10 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): - # TODO: for tp support - # if "OpenAIMoeExperts" in config.__class__.__name__: - # return config return config def is_serializable(self, safe_serialization=None): - return True + return False @property def is_trainable(self) -> bool: diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e29ef2a458e4..8ebe5d45557f 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2065,7 +2065,7 @@ class Mxfp4Config(QuantizationConfigMixin): def __init__( self, - modules_to_not_convert: Optional[List] = None, + modules_to_not_convert: Optional[list] = None, **kwargs, ): self.quant_method = QuantizationMethod.MXFP4 From d53cb49e12459b1ca0a4041d5c19f0c5167036ca Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 10 Jul 2025 10:22:41 +0000 Subject: [PATCH 191/342] consider blocks and scales --- .../convert_openai_weights_to_hf.py | 55 ++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index c8daf8a43ca5..3db18a398bc4 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -16,6 +16,7 @@ import gc import json import os +import math from pathlib import Path from typing import List, Optional @@ -152,7 +153,8 @@ def write_model( print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} for file in list(os.listdir(input_base_path)): - if file.endswith(".safetensors"): + # TODO: remove that + if file.endswith("of-00007.safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) print("Converting ..") @@ -324,6 +326,57 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) +FP4_VALUES = [ + +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + # to match for now existing implementation + out = out.to(torch.float8_e5m2) + return out.to(torch.float8_e5m2) + class OpenAIMoeConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): From 8c43631f7f8666b534f633859d919599ef1e1301 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 10 Jul 2025 10:27:20 +0000 Subject: [PATCH 192/342] device mesh fix --- src/transformers/modeling_utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 28efaea59318..0b2a1785c94a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2266,13 +2266,8 @@ def post_init(self): raise ValueError( f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - - if ( - is_torch_greater_or_equal("2.5") - and _torch_distributed_available - and hasattr(self.config, "device_mesh") - and self.config.device_mesh is not None - ): + + if is_torch_greater_or_equal("2.5") and _torch_distributed_available and getattr(self.config, "device_mesh", None) is not None: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh for name, module in self.named_modules(): From 4f515ebc23202e20bf7a9a1270980ef02fae175f Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 16 Jul 2025 08:51:32 +0000 Subject: [PATCH 193/342] initial commit --- src/transformers/integrations/__init__.py | 2 +- src/transformers/integrations/mxfp4.py | 155 +++++++++++++---- src/transformers/modeling_utils.py | 2 +- .../convert_openai_weights_to_hf.py | 5 +- .../models/openai_moe/modeling_openai_moe.py | 2 +- .../quantizers/quantizer_mxfp4.py | 164 ++++++++++++++++-- 6 files changed, 281 insertions(+), 49 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 4663e811a521..713883ac668e 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight", "convert_moe_packed_tensors"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index f1c16d6552a6..43ed397db185 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -27,12 +27,19 @@ logger = logging.get_logger(__name__) +FP4_VALUES = [ + +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx - swizzle_axis = 2 if swizzle_mx_scale else None + + + swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None w = w.to(torch.bfloat16) + w, mx_scales, weight_scale_shape = downcast_to_mxfp( w, torch.uint8, @@ -66,55 +73,116 @@ def shuffle_weight(w: "torch.Tensor") -> "torch.Tensor": w_shuffled = stacked.reshape(shape) return w_shuffled +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + import math + + scales = scales.to(torch.int32) - 127 + + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + # to match for now existing implementation + return out + # maybe subclass class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): + from triton_kernels.matmul_ogs import InFlexData, FlexCtx, PrecisionConfig, MicroscalingCtx + + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType + super().__init__() - self.num_experts = config.num_local_experts + self.num_experts = config.num_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - + # print("#######################expert_dim: ", self.expert_dim) + # print("#######################num_experts: ", self.num_experts) + # print(config) # dtype torch.float4_e2m1fn not supported yet - self.gate_up_proj = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim, dtype=torch.float8_e5m2), + self.gate_up_proj_blocks = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size//32, 16, dtype=torch.uint8), requires_grad=False, + ) + self.gate_up_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size//32, dtype=torch.uint8), requires_grad=False, ) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter( - torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e5m2), + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=torch.float32), requires_grad=False) + + self.down_proj_blocks = nn.Parameter( + torch.zeros((self.num_experts, self.expert_dim, self.hidden_size//32, 16), dtype=torch.uint8), requires_grad=False, + ) + self.down_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, self.expert_dim, self.hidden_size//32, dtype=torch.uint8), requires_grad=False, ) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.expert_dim, dtype=torch.float32), requires_grad=False) self.alpha = 1.702 - + + self.gate_up_proj_precision_config = None self.down_proj_precision_config = None smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x - self.gate_up_proj_right_pad = smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + self.gate_up_proj_right_pad = 0#smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 self.gate_up_proj_bottom_pad = 0 - self.down_proj_right_pad = smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - self.down_proj_bottom_pad = self.gate_up_proj_right_pad // 2 + self.down_proj_right_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size + self.down_proj_bottom_pad = 0#self.gate_up_proj_right_pad // 2 + self.hidden_size_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: """ To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 """ - # type check, uint8 means mxfp4 #TODO: fp8 x mxfp4 on blackwell assert hidden_states.dtype == torch.bfloat16 - assert self.gate_up_proj.dtype in (torch.bfloat16, torch.uint8) - assert self.down_proj.dtype in (torch.bfloat16, torch.uint8) - assert self.gate_up_proj_bias.dtype == torch.float32 - assert self.down_proj_bias.dtype == torch.float32 + assert self.gate_up_proj_blocks.dtype in (torch.bfloat16, torch.uint8) + assert self.down_proj_blocks.dtype in (torch.bfloat16, torch.uint8) + # assert self.gate_up_proj_bias.dtype == torch.float32, "expected float32 for gate_up_proj_bias and got " + str(self.gate_up_proj_bias.dtype) + # assert self.down_proj_bias.dtype == torch.float32, "expected float32 for down_proj_bias and got " + str(self.down_proj_bias.dtype) # Shape check, only check non-mxfp4 - if self.gate_up_proj.dtype != torch.uint8: + if self.gate_up_proj_blocks.dtype != torch.uint8: assert hidden_states.ndim == 2 - assert hidden_states.shape[-1] == self.gate_up_proj.shape[-2] - assert self.down_proj.shape[-1] == self.gate_up_proj.shape[1] + assert hidden_states.shape[-1] == self.gate_up_proj_blocks.shape[-2] + assert self.down_proj_blocks.shape[-1] == self.gate_up_proj_blocks.shape[1] from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn @@ -122,25 +190,54 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter with torch.cuda.device(hidden_states.device): act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) + if self.hidden_size_pad is not None: + hidden_states = torch.nn.functional.pad(hidden_states, + (0, self.hidden_size_pad, 0, 0), + mode="constant", + value=0) + apply_router_weight_on_input = False + # if torch.isnan(hidden_states).any(): + # raise ValueError("NaNs detected in hidden_states.") + intermediate_cache1 = matmul_ogs(hidden_states, self.gate_up_proj, - self.gate_up_proj_bias, + self.gate_up_proj_bias.to(torch.float32), routing_data, gather_indx=gather_idx, precision_config=self.gate_up_proj_precision_config, gammas=routing_data.gate_scal if apply_router_weight_on_input else None, fused_activation=act) + torch.cuda.synchronize() + + if torch.isnan(intermediate_cache1).any(): + print(f"hidden_states.shape: {hidden_states.shape}, self.gate_up_proj.shape: {self.gate_up_proj.shape}") + print(f"precision_config: {self.gate_up_proj_precision_config}") + print(f"hidden_states.dtype: {hidden_states.dtype}, self.gate_up_proj.dtype: {self.gate_up_proj.dtype}, self.gate_up_proj_bias.dtype: {self.gate_up_proj_bias.dtype}") + + raise ValueError("NaNs detected in intermediate_cache1 after matmul_ogs (gate_up_proj).") + + with torch.cuda.device(hidden_states.device): intermediate_cache3 = matmul_ogs( intermediate_cache1, self.down_proj, - self.down_proj_bias, + self.down_proj_bias.to(torch.float32), routing_data, scatter_indx=scatter_idx, precision_config=self.down_proj_precision_config, gammas=None if apply_router_weight_on_input else routing_data.gate_scal) - + + torch.cuda.synchronize() + if torch.isnan(intermediate_cache3).any(): + print(f"routing_data: {routing_data}") + print(f"scatter_idx: {scatter_idx}") + print(f"intermediate_cache3.shape: {intermediate_cache3.shape}") + print(f"self.down_proj.shape: {self.down_proj.shape}") + print(f"self.down_proj_bias.shape: {self.down_proj_bias.shape}") + print(f"self.down_proj_precision_config: {self.down_proj_precision_config}") + raise ValueError("NaNs detected in intermediate_cache3 after matmul_ogs (down_proj).") + # manually crop the tensor since oai kernel pad the output output_states = intermediate_cache3[..., :self.hidden_size].contiguous() torch.cuda.synchronize() @@ -180,16 +277,16 @@ def _replace_with_mxfp4_linear( if not should_convert_module(current_key_name, modules_to_not_convert): current_key_name.pop(-1) continue - if isinstance(module, nn.Linear): - raise NotImplementedError("Mxfp4 linear layer is not implemented yet") + # if isinstance(module, nn.Linear): + # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") if module.__class__.__name__ == "OpenAIMoeExperts": with init_empty_weights(): # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True - if module.__class__.__name__ == "OpenAIMoeMLP": - from types import MethodType - module.forward = MethodType(mlp_forward, module) + # if module.__class__.__name__ == "OpenAIMoeMLP": + # from types import MethodType + # module.forward = MethodType(mlp_forward, module) if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_mxfp4_linear( module, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0b2a1785c94a..e0b7ca4aaf1c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -784,7 +784,6 @@ def _load_state_dict_into_meta_model( for param_name, empty_param in state_dict.items(): if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling continue - # we need to use serialized_param_name as file pointer is untouched if is_meta_state_dict: # This is the name of the parameter as it appears on disk file @@ -855,6 +854,7 @@ def _load_state_dict_into_meta_model( hf_quantizer.create_quantized_param( model, param, param_name, param_device, state_dict, unexpected_keys ) + # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # and then cast it to CPU to avoid excessive memory usage on each GPU # in comparison to the sharded model across GPUs. diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 3db18a398bc4..72911b8960be 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -154,7 +154,7 @@ def write_model( final_ = {} for file in list(os.listdir(input_base_path)): # TODO: remove that - if file.endswith("of-00007.safetensors"): + if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) print("Converting ..") @@ -210,6 +210,9 @@ def write_model( if not re.search("norm", new_key): weight = weight.to(torch.bfloat16) # norms are the only ones in float32 state_dict[new_key] = weight + if "bias" in new_key and "down_proj" in new_key: + print("new_key: ", new_key) + print(f"bias.shape: {state_dict[new_key].shape}") del final_ gc.collect() diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e9cc409eed8d..ac032620b875 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -115,7 +115,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) + next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(hidden_states.dtype)) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 005700c42734..6c24054eb7a1 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -94,7 +94,7 @@ def validate_environment(self, *args, **kwargs): swizzle_mx_scale = None elif major < 10: swizzle_mx_value = SwizzlingType.HOPPER - swizzle_mx_scale = SwizzlingType.HOPPER + swizzle_mx_scale = None else: swizzle_mx_value = None swizzle_mx_scale = SwizzlingType.BLACKWELL @@ -140,8 +140,9 @@ def create_quantized_param( state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, ): - from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight + from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight, convert_moe_packed_tensors from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + from ..modeling_utils import _load_parameter_into_model if not self.pre_quantized: module, _ = get_module_from_name(model, param_name) @@ -174,23 +175,78 @@ def create_quantized_param( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + # we take this path if alredy quantized but not in a compatible way: + else: + module, _ = get_module_from_name(model, param_name) + if isinstance(module, Mxfp4OpenAIMoeExperts): + if "gate_up_proj" in param_name: + if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": + _load_parameter_into_model(model, param_name, param_value) + return + else: + # In this case the weights are already on the device, so param_value should be the scale value + if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError(f"Something went horribly wrong mate in gate_up_proj") + + dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + # del dequantized_gate_up_proj + # torch.cuda.empty_cache() + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.gate_up_proj_precision_config = None #PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = dequantized_gate_up_proj #torch.nn.Parameter(loaded_weight, requires_grad=False) + + elif "down_proj" in param_name: + if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": + _load_parameter_into_model(model, param_name, param_value) + return + else: + if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError(f"Something went horribly wrong mate in down_proj") + + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) + + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + + loaded_weight = torch.nn.functional.pad(dequantized_down_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + # del dequantized_down_proj + + # torch.cuda.empty_cache() + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.down_proj_precision_config = None #PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = dequantized_down_proj #torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts - if not self.pre_quantized: - for module in model.modules(): - if isinstance(module, Mxfp4OpenAIMoeExperts): - gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) - gate_up_proj_bias = gate_up_proj_bias.to(torch.float32) - gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), - mode="constant", - value=0) - down_proj_bias = module.down_proj_bias.to(torch.float32) - down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), - mode="constant", - value=0) - module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + # if not self.pre_quantized: + # for module in model.modules(): + # if isinstance(module, Mxfp4OpenAIMoeExperts): + # # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) + # gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) + # # gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), + # # mode="constant", + # # value=0) + # down_proj_bias = module.down_proj_bias.to(torch.float32) + # # down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), + # # mode="constant", + # # value=0) + # module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + # module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + reverse_replace_with_mxfp4_linear(model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=model.config, tp_plan=model._tp_plan) return model def _process_model_before_weight_loading( @@ -241,3 +297,79 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: return False + + +def _reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + config=None, + tp_plan=None, +): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + # if isinstance(module, nn.Linear): + # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") + from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + from accelerate import init_empty_weights + if module.__class__.__name__ == "Mxfp4OpenAIMoeExperts": + + # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + gate_up_proj = module.gate_up_proj + down_proj = module.down_proj + gate_up_proj_bias = module.gate_up_proj_bias + down_proj_bias = module.down_proj_bias + model._modules[name] = OpenAIMoeExperts(config) + model._modules[name].gate_up_proj = torch.nn.Parameter(gate_up_proj.transpose(1,2), requires_grad=False) + model._modules[name].down_proj = torch.nn.Parameter(down_proj, requires_grad=False) + model._modules[name].gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + model._modules[name].down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + has_been_replaced=True + if len(list(module.children())) > 0: + _, has_been_replaced = _reverse_replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + config=config, + tp_plan=tp_plan, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + +def reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + config=None, + tp_plan=None, +): + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert, + current_key_name, + quantization_config, + config=config, + tp_plan=tp_plan, + ) + if not has_been_replaced: + logger.warning( + "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model From 0ff6727263f1be47ef5047ef2d440d6532b1c3ee Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 16 Jul 2025 11:03:37 +0000 Subject: [PATCH 194/342] add working dequant + quant logic --- src/transformers/integrations/__init__.py | 2 +- src/transformers/integrations/mxfp4.py | 104 ++++++++--- .../models/openai_moe/modeling_openai_moe.py | 2 +- .../quantizers/quantizer_mxfp4.py | 172 ++++++------------ src/transformers/utils/__init__.py | 2 + src/transformers/utils/import_utils.py | 14 +- src/transformers/utils/quantization_config.py | 2 + 7 files changed, 153 insertions(+), 145 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 713883ac668e..f0e18ddff2d6 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight", "convert_moe_packed_tensors"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight", "convert_moe_packed_tensors", "reverse_replace_with_mxfp4_linear"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 43ed397db185..d759ed87605b 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -84,7 +84,6 @@ def convert_moe_packed_tensors( scales = scales.to(torch.int32) - 127 - assert blocks.shape[:-1] == scales.shape, ( f"{blocks.shape=} does not match {scales.shape=}" ) @@ -159,7 +158,7 @@ def __init__(self, config): smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x - self.gate_up_proj_right_pad = 0#smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + self.gate_up_proj_right_pad = 0# smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 self.gate_up_proj_bottom_pad = 0 self.down_proj_right_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size @@ -197,9 +196,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter value=0) apply_router_weight_on_input = False - # if torch.isnan(hidden_states).any(): - # raise ValueError("NaNs detected in hidden_states.") - + intermediate_cache1 = matmul_ogs(hidden_states, self.gate_up_proj, self.gate_up_proj_bias.to(torch.float32), @@ -211,13 +208,6 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter torch.cuda.synchronize() - if torch.isnan(intermediate_cache1).any(): - print(f"hidden_states.shape: {hidden_states.shape}, self.gate_up_proj.shape: {self.gate_up_proj.shape}") - print(f"precision_config: {self.gate_up_proj_precision_config}") - print(f"hidden_states.dtype: {hidden_states.dtype}, self.gate_up_proj.dtype: {self.gate_up_proj.dtype}, self.gate_up_proj_bias.dtype: {self.gate_up_proj_bias.dtype}") - - raise ValueError("NaNs detected in intermediate_cache1 after matmul_ogs (gate_up_proj).") - with torch.cuda.device(hidden_states.device): intermediate_cache3 = matmul_ogs( intermediate_cache1, @@ -229,15 +219,6 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter gammas=None if apply_router_weight_on_input else routing_data.gate_scal) torch.cuda.synchronize() - if torch.isnan(intermediate_cache3).any(): - print(f"routing_data: {routing_data}") - print(f"scatter_idx: {scatter_idx}") - print(f"intermediate_cache3.shape: {intermediate_cache3.shape}") - print(f"self.down_proj.shape: {self.down_proj.shape}") - print(f"self.down_proj_bias.shape: {self.down_proj_bias.shape}") - print(f"self.down_proj_precision_config: {self.down_proj_precision_config}") - raise ValueError("NaNs detected in intermediate_cache3 after matmul_ogs (down_proj).") - # manually crop the tensor since oai kernel pad the output output_states = intermediate_cache3[..., :self.hidden_size].contiguous() torch.cuda.synchronize() @@ -284,9 +265,9 @@ def _replace_with_mxfp4_linear( # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True - # if module.__class__.__name__ == "OpenAIMoeMLP": - # from types import MethodType - # module.forward = MethodType(mlp_forward, module) + if module.__class__.__name__ == "OpenAIMoeMLP" and not quantization_config.dequantize: + from types import MethodType + module.forward = MethodType(mlp_forward, module) if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_mxfp4_linear( module, @@ -330,3 +311,78 @@ def replace_with_mxfp4_linear( ) return model + +def _reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + config=None, + tp_plan=None, +): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + # if isinstance(module, nn.Linear): + # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") + from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + from accelerate import init_empty_weights + if module.__class__.__name__ == "Mxfp4OpenAIMoeExperts": + + # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + gate_up_proj = module.gate_up_proj + down_proj = module.down_proj + gate_up_proj_bias = module.gate_up_proj_bias + down_proj_bias = module.down_proj_bias + model._modules[name] = OpenAIMoeExperts(config) + model._modules[name].gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=False) + model._modules[name].down_proj = torch.nn.Parameter(down_proj, requires_grad=False) + model._modules[name].gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + model._modules[name].down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + has_been_replaced=True + if len(list(module.children())) > 0: + _, has_been_replaced = _reverse_replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + config=config, + tp_plan=tp_plan, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + +def reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + config=None, + tp_plan=None, +): + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _reverse_replace_with_mxfp4_linear( + model, + modules_to_not_convert, + current_key_name, + quantization_config, + config=config, + tp_plan=tp_plan, + ) + if not has_been_replaced: + logger.warning( + "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index ac032620b875..8c69c03738f4 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -127,7 +127,7 @@ class TopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok - self.num_experts = config.num_local_experts + self.num_experts = config.num_experts if hasattr(config, "num_experts") else config.num_local_experts self.hidden_dim = config.hidden_size self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) self.bias = nn.Parameter(torch.empty(self.num_experts)) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 6c24054eb7a1..1c6f2396c86a 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_torch_available, logging, is_accelerate_available +from ..utils import is_torch_available, logging, is_accelerate_available, is_triton_kernels_availalble, is_triton_available from .quantizers_utils import get_module_from_name @@ -55,11 +55,19 @@ def validate_environment(self, *args, **kwargs): compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability - # TODO: Fix that - # if not is_triton_kernels_availalble(): - # raise ValueError( - # "MXFP4 quantization requires triton_kernels library" - # ) + + if not is_triton_available("3.4.0") or not is_triton_kernels_availalble(): + if self.pre_quantized: + logger.warning_once( + "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" + ) + self.quantization_config.dequantize = True + else: + # we can't quantize the model in this case so we raise an error + raise ValueError( + "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" + ) + if major < 9: raise ValueError( "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" @@ -175,7 +183,7 @@ def create_quantized_param( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - # we take this path if alredy quantized but not in a compatible way: + # we take this path if already quantized but not in a compatible way: else: module, _ = get_module_from_name(model, param_name) if isinstance(module, Mxfp4OpenAIMoeExperts): @@ -191,6 +199,11 @@ def create_quantized_param( raise ValueError(f"Something went horribly wrong mate in gate_up_proj") dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) + + if self.quantization_config.dequantize: + module.gate_up_proj = dequantized_gate_up_proj + return right_pad = module.gate_up_proj_right_pad bottom_pad = module.gate_up_proj_bottom_pad @@ -198,11 +211,12 @@ def create_quantized_param( (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0) - # del dequantized_gate_up_proj - # torch.cuda.empty_cache() - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.gate_up_proj_precision_config = None #PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = dequantized_gate_up_proj #torch.nn.Parameter(loaded_weight, requires_grad=False) + del dequantized_gate_up_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) elif "down_proj" in param_name: if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": @@ -215,38 +229,46 @@ def create_quantized_param( raise ValueError(f"Something went horribly wrong mate in down_proj") dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - + dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) + + if self.quantization_config.dequantize: + module.down_proj = dequantized_down_proj + return + right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_down_proj, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0) - # del dequantized_down_proj - - # torch.cuda.empty_cache() - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.down_proj_precision_config = None #PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = dequantized_down_proj #torch.nn.Parameter(loaded_weight, requires_grad=False) + del dequantized_down_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): - from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts - # if not self.pre_quantized: - # for module in model.modules(): - # if isinstance(module, Mxfp4OpenAIMoeExperts): - # # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) - # gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) - # # gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), - # # mode="constant", - # # value=0) - # down_proj_bias = module.down_proj_bias.to(torch.float32) - # # down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), - # # mode="constant", - # # value=0) - # module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - # module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) - reverse_replace_with_mxfp4_linear(model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=model.config, tp_plan=model._tp_plan) + from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts, reverse_replace_with_mxfp4_linear + if not self.pre_quantized: + for module in model.modules(): + if isinstance(module, Mxfp4OpenAIMoeExperts): + # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) + gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) + gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), + mode="constant", + value=0) + down_proj_bias = module.down_proj_bias.to(torch.float32) + down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), + mode="constant", + value=0) + module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + + if self.quantization_config.dequantize: + reverse_replace_with_mxfp4_linear(model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=model.config, tp_plan=model._tp_plan) + return model def _process_model_before_weight_loading( @@ -296,80 +318,4 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: - return False - - -def _reverse_replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, - config=None, - tp_plan=None, -): - if current_key_name is None: - current_key_name = [] - - for name, module in model.named_children(): - current_key_name.append(name) - # if isinstance(module, nn.Linear): - # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") - from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts - from accelerate import init_empty_weights - if module.__class__.__name__ == "Mxfp4OpenAIMoeExperts": - - # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None - gate_up_proj = module.gate_up_proj - down_proj = module.down_proj - gate_up_proj_bias = module.gate_up_proj_bias - down_proj_bias = module.down_proj_bias - model._modules[name] = OpenAIMoeExperts(config) - model._modules[name].gate_up_proj = torch.nn.Parameter(gate_up_proj.transpose(1,2), requires_grad=False) - model._modules[name].down_proj = torch.nn.Parameter(down_proj, requires_grad=False) - model._modules[name].gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - model._modules[name].down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) - has_been_replaced=True - if len(list(module.children())) > 0: - _, has_been_replaced = _reverse_replace_with_mxfp4_linear( - module, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - config=config, - tp_plan=tp_plan, - ) - current_key_name.pop(-1) - return model, has_been_replaced - - -def reverse_replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - config=None, - tp_plan=None, -): - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _reverse_replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - tp_plan=tp_plan, - ) - if not has_been_replaced: - logger.warning( - "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." - " Please double check your model architecture, or submit an issue on github if you think this is" - " a bug." - ) - - return model + return False \ No newline at end of file diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 0bb3709a42ea..4b96244f75b6 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -273,6 +273,8 @@ is_yt_dlp_available, requires_backends, torch_only_method, + is_triton_available, + is_triton_kernels_availalble, ) from .peft_utils import ( ADAPTER_CONFIG_NAME, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 310b00eb7300..efd6c45e710f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -111,6 +111,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ VPTQ_MIN_VERSION = "0.0.4" TORCHAO_MIN_VERSION = "0.4.0" AUTOROUND_MIN_VERSION = "0.5.0" +TRITON_MIN_VERSION = "3.4.0" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") @@ -225,12 +226,13 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _tiktoken_available = _is_package_available("tiktoken") _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") -_triton_available = _is_package_available("triton") _spqr_available = _is_package_available("spqr_quant") _rich_available = _is_package_available("rich") _kernels_available = _is_package_available("kernels") _matplotlib_available = _is_package_available("matplotlib") _mistral_common_available = _is_package_available("mistral_common") +_triton_available, _triton_version = _is_package_available("triton", return_version=True) +_triton_kernels_available = _is_package_available("triton_kernels") _torch_version = "N/A" _torch_available = False @@ -406,6 +408,11 @@ def is_torch_deterministic(): return False +def is_triton_available(min_version: str = TRITON_MIN_VERSION): + return _triton_available and version.parse(_triton_version) >= version.parse(min_version) + +def is_triton_kernels_availalble(): + return _triton_kernels_available def is_hadamard_available(): return _hadamard_available @@ -1578,11 +1585,6 @@ def is_liger_kernel_available(): return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") - -def is_triton_available(): - return _triton_available - - def is_rich_available(): return _rich_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 8ebe5d45557f..06e3c1b804dc 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2066,7 +2066,9 @@ class Mxfp4Config(QuantizationConfigMixin): def __init__( self, modules_to_not_convert: Optional[list] = None, + dequantize: bool = False, **kwargs, ): self.quant_method = QuantizationMethod.MXFP4 self.modules_to_not_convert = modules_to_not_convert + self.dequantize = dequantize From 13cb07b07c9a24aa06562aeaa5732bfe65c0e58c Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 17 Jul 2025 08:38:49 +0000 Subject: [PATCH 195/342] update --- src/transformers/integrations/mxfp4.py | 2 +- src/transformers/modeling_utils.py | 3 +-- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index d759ed87605b..2cb85f42dda8 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -127,7 +127,7 @@ def __init__(self, config): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType super().__init__() - self.num_experts = config.num_experts + self.num_experts = config.num_experts if hasattr(config, "num_experts") else config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e0b7ca4aaf1c..cd899e6004a9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -720,8 +720,7 @@ def _infer_parameter_dtype( # in int/uint/bool and not cast them. casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - is_param_float8_e5m2 = is_torch_e5m2_available and empty_param.dtype == torch.float8_e5m2 - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn and not is_param_float8_e5m2: + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # First fp32 if part of the exception list if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name): diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 8c69c03738f4..0abee3fde9c6 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -115,7 +115,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj.to(hidden_states.dtype)) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] From 39888563f69cba1bafae7cf13d2dd7e736d911dd Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Mon, 21 Jul 2025 13:24:27 +0000 Subject: [PATCH 196/342] non nan, gibberish output --- src/transformers/integrations/mxfp4.py | 7 +-- src/transformers/modeling_utils.py | 33 ++++++++++- .../quantizers/quantizer_mxfp4.py | 56 +++++++++++++++++-- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 2cb85f42dda8..7e09fd30e8ca 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -122,9 +122,6 @@ def convert_moe_packed_tensors( # maybe subclass class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): - from triton_kernels.matmul_ogs import InFlexData, FlexCtx, PrecisionConfig, MicroscalingCtx - - from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType super().__init__() self.num_experts = config.num_experts if hasattr(config, "num_experts") else config.num_local_experts @@ -196,7 +193,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter value=0) apply_router_weight_on_input = False - + intermediate_cache1 = matmul_ogs(hidden_states, self.gate_up_proj, self.gate_up_proj_bias.to(torch.float32), @@ -228,7 +225,7 @@ def mlp_forward(self, hidden_states): from triton_kernels.routing import routing hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False) + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False, simulated_ep=2) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cd899e6004a9..da83aeb029c6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -798,7 +798,19 @@ def _load_state_dict_into_meta_model( hf_quantizer, ) - if device_mesh is not None: # In this case, the param is already on the correct device! + if device_mesh is not None and ( + not is_quantized + or (not hf_quantizer.requires_parameters_quantization) + or ( + not hf_quantizer.check_quantized_param( + model, + param, + param_name, + state_dict, + device_map=device_map, + ) + ) + ): # In this case, the param is already on the correct device! shard_and_distribute_module( model, param, @@ -809,6 +821,25 @@ def _load_state_dict_into_meta_model( device_mesh.get_local_rank(), device_mesh, ) + elif device_mesh is not None: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: + sharding_kwargs = { + "empty_param": empty_param, + "casting_dtype": casting_dtype, + "to_contiguous": to_contiguous, + "rank": device_mesh.get_local_rank(), + "device_mesh": device_mesh + } + if device_map is None: + param_device = "cpu" + else: + module_layer = re.search(device_map_regex, param_name) + if not module_layer: + raise ValueError(f"{param_name} doesn't have any device set.") + else: + param_device = device_map[module_layer.group()] + hf_quantizer.create_quantized_param( + model, param, param_name, param_device, state_dict, unexpected_keys, **sharding_kwargs + ) else: param = param[...] if casting_dtype is not None: diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 1c6f2396c86a..37d782bcccab 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -147,10 +147,12 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, + **kwargs, ): from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight, convert_moe_packed_tensors from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from ..modeling_utils import _load_parameter_into_model + from ..integrations.tensor_parallel import shard_and_distribute_module if not self.pre_quantized: module, _ = get_module_from_name(model, param_name) @@ -187,14 +189,21 @@ def create_quantized_param( else: module, _ = get_module_from_name(model, param_name) if isinstance(module, Mxfp4OpenAIMoeExperts): + tp_mode = kwargs.get("device_mesh", None) is not None if "gate_up_proj" in param_name: if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": - _load_parameter_into_model(model, param_name, param_value) + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) return else: - # In this case the weights are already on the device, so param_value should be the scale value + # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): - _load_parameter_into_model(model, param_name, param_value) + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) else: raise ValueError(f"Something went horribly wrong mate in gate_up_proj") @@ -220,11 +229,17 @@ def create_quantized_param( elif "down_proj" in param_name: if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": - _load_parameter_into_model(model, param_name, param_value) + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) return else: if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): - _load_parameter_into_model(model, param_name, param_value) + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) else: raise ValueError(f"Something went horribly wrong mate in down_proj") @@ -311,6 +326,37 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): + config.base_model_tp_plan = { + # "embed_tokens": "vocab_parallel_rowwise", + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj_blocks": "local_colwise", + "layers.*.mlp.experts.down_proj_scales": "local_colwise", + "layers.*.mlp.experts.down_proj_bias": "local_colwise", + # "layers.*.mlp": "gather", + } + config.base_model_ep_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.mlp.experts": "gather", + "layers.*.mlp.router": "ep_router", + "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", + } + return config def is_serializable(self, safe_serialization=None): From b9c8138bde28c3421112642812686ea0fd5f0484 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 13:01:02 +0000 Subject: [PATCH 197/342] working EP + quantization finally ! --- src/transformers/integrations/mxfp4.py | 138 +++++++++++++++++- .../models/openai_moe/modeling_openai_moe.py | 1 + .../quantizers/quantizer_mxfp4.py | 40 +++-- 3 files changed, 161 insertions(+), 18 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 7e09fd30e8ca..07e762863a71 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -36,7 +36,6 @@ def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx - swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None w = w.to(torch.bfloat16) @@ -168,6 +167,9 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter """ # type check, uint8 means mxfp4 #TODO: fp8 x mxfp4 on blackwell + # print("routing_data", routing_data) + # print("gather_idx", gather_idx) + # print("scatter_idx", scatter_idx) assert hidden_states.dtype == torch.bfloat16 assert self.gate_up_proj_blocks.dtype in (torch.bfloat16, torch.uint8) assert self.down_proj_blocks.dtype in (torch.bfloat16, torch.uint8) @@ -221,13 +223,139 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter torch.cuda.synchronize() return output_states +def create_expert_indices_for_rank(router_logits, top_k, ep_rank, ep_size, total_experts): + """Create expert indices that only select experts belonging to ep_rank""" + num_local_experts = total_experts // ep_size + router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + + router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] + router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) + router_indices = router_indices % num_local_experts + # Create indices that only point to local experts + return router_indices + + def mlp_forward(self, hidden_states): from triton_kernels.routing import routing hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False, simulated_ep=2) + # expert_indices = create_expert_indices_for_rank(router_logits, self.router.top_k, self.experts.rank, self.experts.device_mesh.size(), self.experts.num_experts) + # print("expert_indices.shape", expert_indices.shape) + # print("expert_indices", expert_indices) + routing_data, gather_idx, scatter_idx = routing_torch_ep_2(router_logits, self.router.top_k) + # print("routing_data", routing_data) + # print("gather_idx", gather_idx) + # print("scatter_idx", scatter_idx) + # raise ValueError("stop here") routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits + +def routing_torch_ep_2( + logits, + n_expts_act, +): + import os + print("in routing_torch_ep_2") + from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch + + with torch.cuda.device(logits.device): + world_size = torch.distributed.get_world_size() + # world_size = 1 + rank = int(os.environ.get("LOCAL_RANK", 0)) + + replace_value = -1 + + n_tokens = logits.shape[0] + n_expts_tot = logits.shape[1] + + n_local_experts = n_expts_tot // world_size + # rank = 1 + local_expert_start = rank * n_local_experts + local_expert_end = (rank + 1) * n_local_experts + + # TODO: check why +20 ? + n_gates_pad = n_tokens * n_expts_act + + def topk(vals, k, expt_indx): + tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] + tk_indx = tk_indx.long() + tk_val = torch.take_along_dim(vals, tk_indx, dim=1) + return tk_val, tk_indx.int() + + expt_scal, expt_indx = topk(logits, n_expts_act, None) + expt_scal = torch.softmax(expt_scal, dim=-1) + expt_indx, sort_indices = torch.sort(expt_indx, dim=1) + expt_scal = torch.gather(expt_scal, 1, sort_indices) + # print("expt_indx and expt_scal") + # print(expt_indx) + # print(expt_scal) + + # Flatten and mask for local experts + expt_scal = expt_scal.reshape(-1) + # print("local_experts") + # print(local_expert_start) + # print(local_expert_end) + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start : local_expert_end] + + # for each row, count how many of its experts are local + # num_local_expts = (expt_indx < local_expert_end) & (local_expert_start <= expt_indx) + # num_local_expts = num_local_expts.sum(dim=1) + + # Count the number of rows that are for local experts, padded to an alignment. + # n_local_rows = (num_local_expts != 0).sum() + # n_local_rows = ((n_local_rows + row_align - 1) // row_align) * row_align + + # TODO: check if reorder really impacts or not the performances + # is_active = torch.argsort((num_local_expts == 0).to(torch.int8), stable=True) + # expt_indx = expt_indx[is_active] + expt_indx = expt_indx.view(-1).to(torch.int32) + # print("expt_indx 2 ") + # print(expt_indx) + + # Note: Because the number of rows routed to each expert is only known at runtime, + # we do not drop tokens that are not routed to the local expert. This ensures that + # the tensor shapes are fixed. + # Create topk_indx/gate_indx. + + # try to move values that were seen to later + # print(expt_indx) + expt_indx = torch.where(expt_indx < local_expert_start, 1000, expt_indx) + # print('after') + # print(expt_indx) + topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32) + # print("topk_indx") + # print(topk_indx) + gate_indx = torch.argsort(topk_indx).to(torch.int32) + # print("gate_indx") + # print(gate_indx) + # Now filter out all experts + expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value) + expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value) + # print(expt_indx) + + gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) + gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) + gate_scal = expt_scal[topk_indx] + # print("updated expt_scal") + # print(gate_scal) + + topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx) + + # print(topk_indx) + + # # Routing metadata for local expert computation + gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) + scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) + + # n_gates_pad = local_expt_indx.numel() + expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad) + # print("expt_data") + # print(expt_data) + # hitted_experts = len(local_expt_indx) -> maybe try to get the closest power of 2 later on + hitted_experts = n_expts_act + return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) @@ -260,6 +388,8 @@ def _replace_with_mxfp4_linear( if module.__class__.__name__ == "OpenAIMoeExperts": with init_empty_weights(): # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + _forward_pre_hooks = module._forward_pre_hooks + _forward_hooks = module._forward_hooks model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True if module.__class__.__name__ == "OpenAIMoeMLP" and not quantization_config.dequantize: @@ -334,7 +464,11 @@ def _reverse_replace_with_mxfp4_linear( down_proj = module.down_proj gate_up_proj_bias = module.gate_up_proj_bias down_proj_bias = module.down_proj_bias + _forward_pre_hooks = module._forward_pre_hooks + _forward_hooks = module._forward_hooks model._modules[name] = OpenAIMoeExperts(config) + model._modules[name]._forward_pre_hooks = _forward_pre_hooks + model._modules[name]._forward_hooks = _forward_hooks model._modules[name].gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=False) model._modules[name].down_proj = torch.nn.Parameter(down_proj, requires_grad=False) model._modules[name].gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 0abee3fde9c6..f286fe95affc 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -89,6 +89,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[1] + self.training = False if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 37d782bcccab..227ef40f8578 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -209,7 +209,10 @@ def create_quantized_param( dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - + + module.device_mesh = kwargs.get("device_mesh") + module.rank = kwargs.get("rank") + if self.quantization_config.dequantize: module.gate_up_proj = dequantized_gate_up_proj return @@ -245,7 +248,8 @@ def create_quantized_param( dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - + module.device_mesh = kwargs.get("device_mesh") + module.rank = kwargs.get("rank") if self.quantization_config.dequantize: module.down_proj = dequantized_down_proj return @@ -266,20 +270,20 @@ def create_quantized_param( def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts, reverse_replace_with_mxfp4_linear - if not self.pre_quantized: - for module in model.modules(): - if isinstance(module, Mxfp4OpenAIMoeExperts): - # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) - gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) - gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), - mode="constant", - value=0) - down_proj_bias = module.down_proj_bias.to(torch.float32) - down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), - mode="constant", - value=0) - module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) + # if not self.pre_quantized: + # for module in model.modules(): + # if isinstance(module, Mxfp4OpenAIMoeExperts): + # # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) + # gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) + # gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), + # mode="constant", + # value=0) + # down_proj_bias = module.down_proj_bias.to(torch.float32) + # down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), + # mode="constant", + # value=0) + # module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) + # module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) if self.quantization_config.dequantize: reverse_replace_with_mxfp4_linear(model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=model.config, tp_plan=model._tp_plan) @@ -333,9 +337,11 @@ def update_tp_plan(self, config): "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.self_attn.sinks": "local_rowwise", + "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_blocks": "local_colwise", "layers.*.mlp.experts.down_proj_scales": "local_colwise", "layers.*.mlp.experts.down_proj_bias": "local_colwise", @@ -349,9 +355,11 @@ def update_tp_plan(self, config): "layers.*.self_attn.sinks": "local_rowwise", "layers.*.mlp.experts": "gather", "layers.*.mlp.router": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", From 5117d71eb12c0a1f777d633373db0a43781d54d7 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 11:08:54 +0000 Subject: [PATCH 198/342] start cleaning --- src/transformers/integrations/mxfp4.py | 2 +- src/transformers/quantizers/quantizer_mxfp4.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 07e762863a71..9db33dd3cf11 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -385,7 +385,7 @@ def _replace_with_mxfp4_linear( continue # if isinstance(module, nn.Linear): # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") - if module.__class__.__name__ == "OpenAIMoeExperts": + if module.__class__.__name__ == "OpenAIMoeExperts" and not quantization_config.dequantize: with init_empty_weights(): # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None _forward_pre_hooks = module._forward_pre_hooks diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 227ef40f8578..2743135b2d21 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -290,6 +290,14 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs return model + + def update_expected_keys(self, model: "PreTrainedModel", expected_keys: List[str]): + from ..integrations import Mxfp4OpenAIMoeExperts + for name, module in model.named_modules(): + if isinstance(module, Mxfp4OpenAIMoeExperts): + expected_keys.append(name) + return expected_keys + def _process_model_before_weight_loading( self, model: "PreTrainedModel", From 3733a3492649b821716affba49ed9b4486a262e0 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 13:45:36 +0000 Subject: [PATCH 199/342] remove reversing process --- src/transformers/integrations/mxfp4.py | 111 ++------- .../integrations/tensor_parallel.py | 5 +- src/transformers/modeling_utils.py | 22 +- .../quantizers/quantizer_mxfp4.py | 218 ++++++++++-------- 4 files changed, 152 insertions(+), 204 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 9db33dd3cf11..c8241dbf7a0b 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -127,10 +127,7 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - # print("#######################expert_dim: ", self.expert_dim) - # print("#######################num_experts: ", self.num_experts) - # print(config) - # dtype torch.float4_e2m1fn not supported yet + self.gate_up_proj_blocks = nn.Parameter( torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size//32, 16, dtype=torch.uint8), requires_grad=False, ) @@ -238,33 +235,30 @@ def create_expert_indices_for_rank(router_logits, top_k, ep_rank, ep_size, total def mlp_forward(self, hidden_states): - from triton_kernels.routing import routing + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + routing = routing_torch_ep_2 + else: + from triton_kernels.routing import routing + routing = routing hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - # expert_indices = create_expert_indices_for_rank(router_logits, self.router.top_k, self.experts.rank, self.experts.device_mesh.size(), self.experts.num_experts) - # print("expert_indices.shape", expert_indices.shape) - # print("expert_indices", expert_indices) - routing_data, gather_idx, scatter_idx = routing_torch_ep_2(router_logits, self.router.top_k) - # print("routing_data", routing_data) - # print("gather_idx", gather_idx) - # print("scatter_idx", scatter_idx) - # raise ValueError("stop here") + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits def routing_torch_ep_2( logits, n_expts_act, + sm_first=False ): import os - print("in routing_torch_ep_2") from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch with torch.cuda.device(logits.device): world_size = torch.distributed.get_world_size() # world_size = 1 rank = int(os.environ.get("LOCAL_RANK", 0)) - replace_value = -1 n_tokens = logits.shape[0] @@ -417,91 +411,16 @@ def replace_with_mxfp4_linear( config=None, tp_plan=None, ): - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - tp_plan=tp_plan, - ) - if not has_been_replaced: - logger.warning( - "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." - " Please double check your model architecture, or submit an issue on github if you think this is" - " a bug." - ) - - return model - -def _reverse_replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, - config=None, - tp_plan=None, -): - if current_key_name is None: - current_key_name = [] - - for name, module in model.named_children(): - current_key_name.append(name) - # if isinstance(module, nn.Linear): - # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") - from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts - from accelerate import init_empty_weights - if module.__class__.__name__ == "Mxfp4OpenAIMoeExperts": - - # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None - gate_up_proj = module.gate_up_proj - down_proj = module.down_proj - gate_up_proj_bias = module.gate_up_proj_bias - down_proj_bias = module.down_proj_bias - _forward_pre_hooks = module._forward_pre_hooks - _forward_hooks = module._forward_hooks - model._modules[name] = OpenAIMoeExperts(config) - model._modules[name]._forward_pre_hooks = _forward_pre_hooks - model._modules[name]._forward_hooks = _forward_hooks - model._modules[name].gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=False) - model._modules[name].down_proj = torch.nn.Parameter(down_proj, requires_grad=False) - model._modules[name].gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - model._modules[name].down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) - has_been_replaced=True - if len(list(module.children())) > 0: - _, has_been_replaced = _reverse_replace_with_mxfp4_linear( - module, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - config=config, - tp_plan=tp_plan, - ) - current_key_name.pop(-1) - return model, has_been_replaced - -def reverse_replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - config=None, - tp_plan=None, -): + if quantization_config.dequantize: + return model + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _reverse_replace_with_mxfp4_linear( + model, has_been_replaced = _replace_with_mxfp4_linear( model, modules_to_not_convert, current_key_name, @@ -509,11 +428,11 @@ def reverse_replace_with_mxfp4_linear( config=config, tp_plan=tp_plan, ) - if not has_been_replaced: + if not has_been_replaced : logger.warning( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." " Please double check your model architecture, or submit an issue on github if you think this is" " a bug." ) - return model + return model \ No newline at end of file diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 456a6dbb952c..96af3c16e0de 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1068,7 +1068,7 @@ def __init__(self): def shard_and_distribute_module( - model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh + model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param_inside=True ): # TODO: rename to shard_and_distribute_param r""" Main uses cases: @@ -1118,7 +1118,8 @@ def shard_and_distribute_module( # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) - setattr(module_to_tp, param_type, param) + if set_param_inside: + setattr(module_to_tp, param_type, param) # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) return param diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index da83aeb029c6..d2230ea1532f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -710,6 +710,7 @@ def _infer_parameter_dtype( if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.QUARK, + QuantizationMethod.MXFP4, }: return True, None else: @@ -829,16 +830,8 @@ def _load_state_dict_into_meta_model( "rank": device_mesh.get_local_rank(), "device_mesh": device_mesh } - if device_map is None: - param_device = "cpu" - else: - module_layer = re.search(device_map_regex, param_name) - if not module_layer: - raise ValueError(f"{param_name} doesn't have any device set.") - else: - param_device = device_map[module_layer.group()] hf_quantizer.create_quantized_param( - model, param, param_name, param_device, state_dict, unexpected_keys, **sharding_kwargs + model, param, param_name, device_mesh.get_local_rank(), state_dict, unexpected_keys, **sharding_kwargs ) else: param = param[...] @@ -6181,8 +6174,15 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, # Skip if the parameter has already been accounted for (tied weights) if param_name in tied_param_names: continue - - param = model.get_parameter_or_buffer(param_name) + try: + param = model.get_parameter_or_buffer(param_name) + except AttributeError: + if hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and ("blocks" in param_name or "scales" in param_name): + neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] + try: + param = model.get_parameter_or_buffer(neutral_param_name) + except AttributeError: + raise AttributeError(f"Parameter {param_name} not found in model") # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` param_byte_count = param.numel() * param.element_size() diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 2743135b2d21..e2f8ada45d01 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -131,9 +131,13 @@ def check_quantized_param( **kwargs, ): from ..integrations import Mxfp4OpenAIMoeExperts - module, tensor_name = get_module_from_name(model, param_name) + from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): + module, tensor_name = get_module_from_name(model, param_name[:-len("_blocks")]) + else: + module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, Mxfp4OpenAIMoeExperts): + if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]: return False return True @@ -153,6 +157,7 @@ def create_quantized_param( from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from ..modeling_utils import _load_parameter_into_model from ..integrations.tensor_parallel import shard_and_distribute_module + from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts if not self.pre_quantized: module, _ = get_module_from_name(model, param_name) @@ -187,116 +192,139 @@ def create_quantized_param( module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) # we take this path if already quantized but not in a compatible way: else: - module, _ = get_module_from_name(model, param_name) - if isinstance(module, Mxfp4OpenAIMoeExperts): + if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize: + # blocks and scales have the same length that's why the below line works + module, _ = get_module_from_name(model, param_name[:-len("_blocks")]) + else: + module, _ = get_module_from_name(model, param_name) + if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): tp_mode = kwargs.get("device_mesh", None) is not None - if "gate_up_proj" in param_name: - if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + if self.quantization_config.dequantize: + neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] + if "gate_up_proj" in param_name: + if not hasattr(module, "gate_up_proj_blocks") and not hasattr(module, "gate_up_proj_scales"): + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + return else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param - if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + + dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) + module.gate_up_proj = torch.nn.Parameter(dequantized_gate_up_proj, requires_grad=False) + del module.gate_up_proj_blocks + del module.gate_up_proj_scales + return + elif "down_proj" in param_name: + if not hasattr(module, "down_proj_blocks") and not hasattr(module, "down_proj_scales"): + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + return + else: + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) + dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) + module.down_proj = torch.nn.Parameter(dequantized_down_proj, requires_grad=False) + del module.down_proj_blocks + del module.down_proj_scales + return + else: + if "gate_up_proj" in param_name: + if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": if tp_mode: shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) else: _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError(f"Something went horribly wrong mate in gate_up_proj") - - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) - dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - - module.device_mesh = kwargs.get("device_mesh") - module.rank = kwargs.get("rank") - - if self.quantization_config.dequantize: - module.gate_up_proj = dequantized_gate_up_proj return - - right_pad = module.gate_up_proj_right_pad - bottom_pad = module.gate_up_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_gate_up_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - - elif "down_proj" in param_name: - if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): + # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param + if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError(f"Something went horribly wrong mate in gate_up_proj") + + dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) + + module.device_mesh = kwargs.get("device_mesh") + module.rank = kwargs.get("rank") + + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del dequantized_gate_up_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + + elif "down_proj" in param_name: + if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": if tp_mode: shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) else: _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError(f"Something went horribly wrong mate in down_proj") - - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - module.device_mesh = kwargs.get("device_mesh") - module.rank = kwargs.get("rank") - if self.quantization_config.dequantize: - module.down_proj = dequantized_down_proj return - - right_pad = module.down_proj_right_pad - bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_down_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_down_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - - module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + else: + if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): + if tp_mode: + shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) + else: + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError(f"Something went horribly wrong mate in down_proj") + + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) + dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) + module.device_mesh = kwargs.get("device_mesh") + module.rank = kwargs.get("rank") + + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(dequantized_down_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del dequantized_down_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): - from ..integrations import shuffle_weight, Mxfp4OpenAIMoeExperts, reverse_replace_with_mxfp4_linear - # if not self.pre_quantized: - # for module in model.modules(): - # if isinstance(module, Mxfp4OpenAIMoeExperts): - # # gate_up_proj_bias = shuffle_weight(module.gate_up_proj_bias) - # gate_up_proj_bias = module.gate_up_proj_bias.to(torch.float32) - # gate_up_proj_bias = torch.nn.functional.pad(gate_up_proj_bias, (0, module.gate_up_proj_right_pad, 0, 0), - # mode="constant", - # value=0) - # down_proj_bias = module.down_proj_bias.to(torch.float32) - # down_proj_bias = torch.nn.functional.pad(down_proj_bias, (0, module.down_proj_right_pad, 0, 0), - # mode="constant", - # value=0) - # module.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias, requires_grad=False) - # module.down_proj_bias = torch.nn.Parameter(down_proj_bias, requires_grad=False) - - if self.quantization_config.dequantize: - reverse_replace_with_mxfp4_linear(model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=model.config, tp_plan=model._tp_plan) - return model - - def update_expected_keys(self, model: "PreTrainedModel", expected_keys: List[str]): - from ..integrations import Mxfp4OpenAIMoeExperts - for name, module in model.named_modules(): - if isinstance(module, Mxfp4OpenAIMoeExperts): - expected_keys.append(name) - return expected_keys + def update_expected_keys(self, model: "PreTrainedModel", expected_keys: List[str], checkpoint_keys: List[str]): + # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants + new_expected_keys = [] + for key in expected_keys: + if key.endswith(".mlp.experts.gate_up_proj"): + base = key[:-len("gate_up_proj")] + new_expected_keys.append(base + "gate_up_proj_blocks") + new_expected_keys.append(base + "gate_up_proj_scales") + elif key.endswith(".mlp.experts.down_proj"): + base = key[:-len("down_proj")] + new_expected_keys.append(base + "down_proj_blocks") + new_expected_keys.append(base + "down_proj_scales") + else: + new_expected_keys.append(key) + return new_expected_keys def _process_model_before_weight_loading( self, From 658735966602a0fd41aa0e2e3d77ac53bc5ab03e Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 13:49:20 +0000 Subject: [PATCH 200/342] style --- src/transformers/integrations/mxfp4.py | 39 ++++++++------- src/transformers/modeling_utils.py | 4 +- .../quantizers/quantizer_mxfp4.py | 49 +++++++++++-------- 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index c8241dbf7a0b..d59457890e59 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -32,9 +32,9 @@ -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] -def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): - from triton_kernels.numerics_details.mxfp import downcast_to_mxfp +def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None w = w.to(torch.bfloat16) @@ -152,11 +152,11 @@ def __init__(self, config): smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x self.gate_up_proj_right_pad = 0# smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 - self.gate_up_proj_bottom_pad = 0 - + self.gate_up_proj_bottom_pad = 0 + self.down_proj_right_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size self.down_proj_bottom_pad = 0#self.gate_up_proj_right_pad // 2 - + self.hidden_size_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: """ @@ -204,7 +204,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter torch.cuda.synchronize() - with torch.cuda.device(hidden_states.device): + with torch.cuda.device(hidden_states.device): intermediate_cache3 = matmul_ogs( intermediate_cache1, self.down_proj, @@ -213,7 +213,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter scatter_indx=scatter_idx, precision_config=self.down_proj_precision_config, gammas=None if apply_router_weight_on_input else routing_data.gate_scal) - + torch.cuda.synchronize() # manually crop the tensor since oai kernel pad the output output_states = intermediate_cache3[..., :self.hidden_size].contiguous() @@ -246,21 +246,22 @@ def mlp_forward(self, hidden_states): routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits - + def routing_torch_ep_2( logits, n_expts_act, sm_first=False ): import os + from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch with torch.cuda.device(logits.device): - world_size = torch.distributed.get_world_size() + world_size = torch.distributed.get_world_size() # world_size = 1 rank = int(os.environ.get("LOCAL_RANK", 0)) replace_value = -1 - + n_tokens = logits.shape[0] n_expts_tot = logits.shape[1] @@ -269,7 +270,7 @@ def routing_torch_ep_2( local_expert_start = rank * n_local_experts local_expert_end = (rank + 1) * n_local_experts - # TODO: check why +20 ? + # TODO: check why +20 ? n_gates_pad = n_tokens * n_expts_act def topk(vals, k, expt_indx): @@ -296,7 +297,7 @@ def topk(vals, k, expt_indx): # for each row, count how many of its experts are local # num_local_expts = (expt_indx < local_expert_end) & (local_expert_start <= expt_indx) # num_local_expts = num_local_expts.sum(dim=1) - + # Count the number of rows that are for local experts, padded to an alignment. # n_local_rows = (num_local_expts != 0).sum() # n_local_rows = ((n_local_rows + row_align - 1) // row_align) * row_align @@ -312,7 +313,7 @@ def topk(vals, k, expt_indx): # we do not drop tokens that are not routed to the local expert. This ensures that # the tensor shapes are fixed. # Create topk_indx/gate_indx. - + # try to move values that were seen to later # print(expt_indx) expt_indx = torch.where(expt_indx < local_expert_start, 1000, expt_indx) @@ -324,7 +325,7 @@ def topk(vals, k, expt_indx): gate_indx = torch.argsort(topk_indx).to(torch.int32) # print("gate_indx") # print(gate_indx) - # Now filter out all experts + # Now filter out all experts expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value) expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value) # print(expt_indx) @@ -334,7 +335,7 @@ def topk(vals, k, expt_indx): gate_scal = expt_scal[topk_indx] # print("updated expt_scal") # print(gate_scal) - + topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx) # print(topk_indx) @@ -342,12 +343,12 @@ def topk(vals, k, expt_indx): # # Routing metadata for local expert computation gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) - + # n_gates_pad = local_expt_indx.numel() expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad) # print("expt_data") # print(expt_data) - # hitted_experts = len(local_expt_indx) -> maybe try to get the closest power of 2 later on + # hitted_experts = len(local_expt_indx) -> maybe try to get the closest power of 2 later on hitted_experts = n_expts_act return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx @@ -414,7 +415,7 @@ def replace_with_mxfp4_linear( if quantization_config.dequantize: return model - + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert if quantization_config.modules_to_not_convert is not None: @@ -435,4 +436,4 @@ def replace_with_mxfp4_linear( " a bug." ) - return model \ No newline at end of file + return model diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d2230ea1532f..f98c85729603 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -51,9 +51,9 @@ from torchao.quantization import Int4WeightOnlyConfig from .configuration_utils import PretrainedConfig +from .distributed import DistributedConfig from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig -from .distributed import DistributedConfig from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.deepspeed import _load_state_dict_into_zero3_model @@ -2289,7 +2289,7 @@ def post_init(self): raise ValueError( f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" ) - + if is_torch_greater_or_equal("2.5") and _torch_distributed_available and getattr(self.config, "device_mesh", None) is not None: # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit device_mesh = self.config.device_mesh diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index e2f8ada45d01..22930e306cdc 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional from .base import HfQuantizer @@ -19,7 +19,13 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_torch_available, logging, is_accelerate_available, is_triton_kernels_availalble, is_triton_available +from ..utils import ( + is_accelerate_available, + is_torch_available, + is_triton_available, + is_triton_kernels_availalble, + logging, +) from .quantizers_utils import get_module_from_name @@ -67,14 +73,14 @@ def validate_environment(self, *args, **kwargs): raise ValueError( "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" ) - + if major < 9: raise ValueError( - "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" + "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)" ) if not is_accelerate_available(): raise ImportError( - f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=1.8.0'`" + "Using mxfp4 requires Accelerate: `pip install 'accelerate>=1.8.0'`" ) device_map = kwargs.get("device_map", None) @@ -127,7 +133,7 @@ def check_quantized_param( model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], **kwargs, ): from ..integrations import Mxfp4OpenAIMoeExperts @@ -149,14 +155,15 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: Dict[str, Any], - unexpected_keys: Optional[List[str]] = None, + state_dict: dict[str, Any], + unexpected_keys: Optional[list[str]] = None, **kwargs, ): - from ..integrations import quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, shuffle_weight, convert_moe_packed_tensors from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - from ..modeling_utils import _load_parameter_into_model + + from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4, shuffle_weight from ..integrations.tensor_parallel import shard_and_distribute_module + from ..modeling_utils import _load_parameter_into_model from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts if not self.pre_quantized: @@ -251,7 +258,7 @@ def create_quantized_param( else: _load_parameter_into_model(model, param_name, param_value) else: - raise ValueError(f"Something went horribly wrong mate in gate_up_proj") + raise ValueError("Something went horribly wrong mate in gate_up_proj") dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) @@ -267,7 +274,7 @@ def create_quantized_param( value=0) del dequantized_gate_up_proj torch.cuda.empty_cache() - with torch.cuda.device(target_device): + with torch.cuda.device(target_device): loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) @@ -285,9 +292,9 @@ def create_quantized_param( shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) else: _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError(f"Something went horribly wrong mate in down_proj") - + else: + raise ValueError("Something went horribly wrong mate in down_proj") + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) module.device_mesh = kwargs.get("device_mesh") @@ -301,16 +308,16 @@ def create_quantized_param( value=0) del dequantized_down_proj torch.cuda.empty_cache() - with torch.cuda.device(target_device): + with torch.cuda.device(target_device): loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): return model - def update_expected_keys(self, model: "PreTrainedModel", expected_keys: List[str], checkpoint_keys: List[str]): + def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]): # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants new_expected_keys = [] for key in expected_keys: @@ -329,7 +336,7 @@ def update_expected_keys(self, model: "PreTrainedModel", expected_keys: List[str def _process_model_before_weight_loading( self, model: "PreTrainedModel", - keep_in_fp32_modules: Optional[List[str]] = None, + keep_in_fp32_modules: Optional[list[str]] = None, **kwargs, ): from ..integrations import replace_with_mxfp4_linear @@ -350,7 +357,7 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config - def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: from ..integrations import Mxfp4OpenAIMoeExperts not_missing_keys = [] @@ -408,4 +415,4 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: - return False \ No newline at end of file + return False From 7961073102cc36ab76df1d5b3f263a283b068643 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 13:53:17 +0000 Subject: [PATCH 201/342] some cleaning --- src/transformers/integrations/mxfp4.py | 51 +++----------------------- 1 file changed, 5 insertions(+), 46 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index d59457890e59..2eb1bb8d9153 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -237,7 +237,7 @@ def create_expert_indices_for_rank(router_logits, top_k, ep_rank, ep_size, total def mlp_forward(self, hidden_states): import torch.distributed as dist if dist.is_available() and dist.is_initialized(): - routing = routing_torch_ep_2 + routing = routing_torch_dist else: from triton_kernels.routing import routing routing = routing @@ -247,7 +247,7 @@ def mlp_forward(self, hidden_states): routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits -def routing_torch_ep_2( +def routing_torch_dist( logits, n_expts_act, sm_first=False @@ -258,7 +258,6 @@ def routing_torch_ep_2( with torch.cuda.device(logits.device): world_size = torch.distributed.get_world_size() - # world_size = 1 rank = int(os.environ.get("LOCAL_RANK", 0)) replace_value = -1 @@ -266,11 +265,9 @@ def routing_torch_ep_2( n_expts_tot = logits.shape[1] n_local_experts = n_expts_tot // world_size - # rank = 1 local_expert_start = rank * n_local_experts local_expert_end = (rank + 1) * n_local_experts - # TODO: check why +20 ? n_gates_pad = n_tokens * n_expts_act def topk(vals, k, expt_indx): @@ -283,72 +280,34 @@ def topk(vals, k, expt_indx): expt_scal = torch.softmax(expt_scal, dim=-1) expt_indx, sort_indices = torch.sort(expt_indx, dim=1) expt_scal = torch.gather(expt_scal, 1, sort_indices) - # print("expt_indx and expt_scal") - # print(expt_indx) - # print(expt_scal) + # Flatten and mask for local experts expt_scal = expt_scal.reshape(-1) - # print("local_experts") - # print(local_expert_start) - # print(local_expert_end) - hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start : local_expert_end] - - # for each row, count how many of its experts are local - # num_local_expts = (expt_indx < local_expert_end) & (local_expert_start <= expt_indx) - # num_local_expts = num_local_expts.sum(dim=1) - # Count the number of rows that are for local experts, padded to an alignment. - # n_local_rows = (num_local_expts != 0).sum() - # n_local_rows = ((n_local_rows + row_align - 1) // row_align) * row_align + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start : local_expert_end] - # TODO: check if reorder really impacts or not the performances - # is_active = torch.argsort((num_local_expts == 0).to(torch.int8), stable=True) - # expt_indx = expt_indx[is_active] expt_indx = expt_indx.view(-1).to(torch.int32) - # print("expt_indx 2 ") - # print(expt_indx) - - # Note: Because the number of rows routed to each expert is only known at runtime, - # we do not drop tokens that are not routed to the local expert. This ensures that - # the tensor shapes are fixed. - # Create topk_indx/gate_indx. - # try to move values that were seen to later - # print(expt_indx) expt_indx = torch.where(expt_indx < local_expert_start, 1000, expt_indx) - # print('after') - # print(expt_indx) topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32) - # print("topk_indx") - # print(topk_indx) gate_indx = torch.argsort(topk_indx).to(torch.int32) - # print("gate_indx") - # print(gate_indx) - # Now filter out all experts expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value) expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value) - # print(expt_indx) gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) gate_scal = expt_scal[topk_indx] - # print("updated expt_scal") - # print(gate_scal) topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx) - # print(topk_indx) # # Routing metadata for local expert computation gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) - # n_gates_pad = local_expt_indx.numel() expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad) - # print("expt_data") - # print(expt_data) - # hitted_experts = len(local_expt_indx) -> maybe try to get the closest power of 2 later on + hitted_experts = n_expts_act return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx From 0de006a23407a1bf63bf34597f030d4d7219576a Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 22 Jul 2025 14:59:06 +0000 Subject: [PATCH 202/342] initial commmit --- .../convert_openai_weights_to_hf.py | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 72911b8960be..ed57224ceb8b 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -80,6 +80,56 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + import math + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + # to match for now existing implementation + return out.to(torch.float8_e5m2) + +FP4_VALUES = [ + +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + def convert_moe_packed_tensors( blocks, scales, @@ -157,7 +207,7 @@ def write_model( if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) - print("Converting ..") + print("Converting ..", unpack) all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) From 12a9e802f54695d2f4155a7a97f3bd7381e7be59 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 23 Jul 2025 14:51:08 +0000 Subject: [PATCH 203/342] more cleaning --- src/transformers/integrations/mxfp4.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 2eb1bb8d9153..e92b610bceb6 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -164,16 +164,10 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter """ # type check, uint8 means mxfp4 #TODO: fp8 x mxfp4 on blackwell - # print("routing_data", routing_data) - # print("gather_idx", gather_idx) - # print("scatter_idx", scatter_idx) assert hidden_states.dtype == torch.bfloat16 assert self.gate_up_proj_blocks.dtype in (torch.bfloat16, torch.uint8) assert self.down_proj_blocks.dtype in (torch.bfloat16, torch.uint8) - # assert self.gate_up_proj_bias.dtype == torch.float32, "expected float32 for gate_up_proj_bias and got " + str(self.gate_up_proj_bias.dtype) - # assert self.down_proj_bias.dtype == torch.float32, "expected float32 for down_proj_bias and got " + str(self.down_proj_bias.dtype) - # Shape check, only check non-mxfp4 if self.gate_up_proj_blocks.dtype != torch.uint8: assert hidden_states.ndim == 2 assert hidden_states.shape[-1] == self.gate_up_proj_blocks.shape[-2] @@ -220,20 +214,6 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter torch.cuda.synchronize() return output_states -def create_expert_indices_for_rank(router_logits, top_k, ep_rank, ep_size, total_experts): - """Create expert indices that only select experts belonging to ep_rank""" - num_local_experts = total_experts // ep_size - router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - - router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] - router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) - router_indices = router_indices % num_local_experts - # Create indices that only point to local experts - return router_indices - - def mlp_forward(self, hidden_states): import torch.distributed as dist if dist.is_available() and dist.is_initialized(): @@ -337,13 +317,8 @@ def _replace_with_mxfp4_linear( if not should_convert_module(current_key_name, modules_to_not_convert): current_key_name.pop(-1) continue - # if isinstance(module, nn.Linear): - # raise NotImplementedError("Mxfp4 linear layer is not implemented yet") if module.__class__.__name__ == "OpenAIMoeExperts" and not quantization_config.dequantize: with init_empty_weights(): - # tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None - _forward_pre_hooks = module._forward_pre_hooks - _forward_hooks = module._forward_hooks model._modules[name] = Mxfp4OpenAIMoeExperts(config) has_been_replaced=True if module.__class__.__name__ == "OpenAIMoeMLP" and not quantization_config.dequantize: From 39047834dc71f59e1af17e60ddfab735e38596b6 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 08:39:47 +0000 Subject: [PATCH 204/342] more cleaning --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/mxfp4.py | 71 ++++--------------- src/transformers/modeling_utils.py | 2 - .../models/openai_moe/modeling_openai_moe.py | 3 +- .../quantizers/quantizer_mxfp4.py | 9 +-- 5 files changed, 20 insertions(+), 69 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index f0e18ddff2d6..b530a65f0005 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "shuffle_weight", "convert_moe_packed_tensors", "reverse_replace_with_mxfp4_linear"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "convert_moe_packed_tensors"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,7 +256,7 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, shuffle_weight, Mxfp4OpenAIMoeExperts + from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, Mxfp4OpenAIMoeExperts from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index e92b610bceb6..9dc3e0d27681 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -52,26 +52,6 @@ def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): swizzle_value=swizzle_mx_value, actual_weight_scale_shape=weight_scale_shape) -def shuffle_weight(w: "torch.Tensor") -> "torch.Tensor": - # Shuffle weight along the last dimension so that - # we folded the weights to adjance location - # Example: - # input: - # [[1, 2, 3, 4, 5, 6], - # [7, 8, 9, 10, 11, 12]] - # output: - # [[1, 4, 2, 5, 3, 6], - # [7, 10, 8, 11, 9, 12]] - # This will be used together with triton swiglu kernel - shape = w.shape - N = shape[-1] - first = w[..., :N // 2] - second = w[..., N // 2:] - - stacked = torch.stack((first, second), dim=-1) - w_shuffled = stacked.reshape(shape) - return w_shuffled - def convert_moe_packed_tensors( blocks, scales, @@ -115,15 +95,13 @@ def convert_moe_packed_tensors( del idx_lo, idx_hi, blk, exp out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) - # to match for now existing implementation return out -# maybe subclass class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_experts if hasattr(config, "num_experts") else config.num_local_experts + self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size @@ -149,33 +127,21 @@ def __init__(self, config): self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x + # TODO: To remove once we make sure that we don't need this + # smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x - self.gate_up_proj_right_pad = 0# smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + self.gate_up_proj_right_pad = 0 #smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 self.gate_up_proj_bottom_pad = 0 + self.down_proj_right_pad = 0 #smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size + self.down_proj_bottom_pad = 0 #self.gate_up_proj_right_pad // 2 + self.hidden_size_pad = 0 #smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - self.down_proj_right_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - self.down_proj_bottom_pad = 0#self.gate_up_proj_right_pad // 2 - - self.hidden_size_pad = 0#smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: """ To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 """ - # type check, uint8 means mxfp4 - #TODO: fp8 x mxfp4 on blackwell - assert hidden_states.dtype == torch.bfloat16 - assert self.gate_up_proj_blocks.dtype in (torch.bfloat16, torch.uint8) - assert self.down_proj_blocks.dtype in (torch.bfloat16, torch.uint8) - - if self.gate_up_proj_blocks.dtype != torch.uint8: - assert hidden_states.ndim == 2 - assert hidden_states.shape[-1] == self.gate_up_proj_blocks.shape[-2] - assert self.down_proj_blocks.shape[-1] == self.gate_up_proj_blocks.shape[1] - from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn - # TODO: needed in the context of device_map, maybe not for TP with torch.cuda.device(hidden_states.device): act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) @@ -185,20 +151,15 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter mode="constant", value=0) - apply_router_weight_on_input = False - intermediate_cache1 = matmul_ogs(hidden_states, self.gate_up_proj, self.gate_up_proj_bias.to(torch.float32), routing_data, gather_indx=gather_idx, precision_config=self.gate_up_proj_precision_config, - gammas=routing_data.gate_scal if apply_router_weight_on_input else None, + gammas=None, fused_activation=act) - torch.cuda.synchronize() - - with torch.cuda.device(hidden_states.device): intermediate_cache3 = matmul_ogs( intermediate_cache1, self.down_proj, @@ -206,13 +167,9 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter routing_data, scatter_indx=scatter_idx, precision_config=self.down_proj_precision_config, - gammas=None if apply_router_weight_on_input else routing_data.gate_scal) + gammas=routing_data.gate_scal) - torch.cuda.synchronize() - # manually crop the tensor since oai kernel pad the output - output_states = intermediate_cache3[..., :self.hidden_size].contiguous() - torch.cuda.synchronize() - return output_states + return intermediate_cache3 def mlp_forward(self, hidden_states): import torch.distributed as dist @@ -223,14 +180,13 @@ def mlp_forward(self, hidden_states): routing = routing hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k, sm_first=False) + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits def routing_torch_dist( logits, n_expts_act, - sm_first=False ): import os @@ -269,13 +225,14 @@ def topk(vals, k, expt_indx): expt_indx = expt_indx.view(-1).to(torch.int32) - expt_indx = torch.where(expt_indx < local_expert_start, 1000, expt_indx) + # we use a large value to replace the indices that are not in the local expert range + var = 1000 + expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx) topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32) gate_indx = torch.argsort(topk_indx).to(torch.int32) expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value) expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value) - gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) gate_scal = expt_scal[topk_indx] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f98c85729603..1177f04d85eb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -716,13 +716,11 @@ def _infer_parameter_dtype( else: raise e is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - is_torch_e5m2_available = hasattr(torch, "float8_e5m2") # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # in int/uint/bool and not cast them. casting_dtype = None is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - # if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # First fp32 if part of the exception list if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name): casting_dtype = torch.float32 diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index f286fe95affc..e9cc409eed8d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -89,7 +89,6 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) num_experts = routing_weights.shape[1] - self.training = False if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): @@ -128,7 +127,7 @@ class TopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok - self.num_experts = config.num_experts if hasattr(config, "num_experts") else config.num_local_experts + self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) self.bias = nn.Parameter(torch.empty(self.num_experts)) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 22930e306cdc..212be5fbce9f 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -101,7 +101,7 @@ def validate_environment(self, *args, **kwargs): "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." ) from triton_kernels.numerics_details.mxfp import SwizzlingType - + # TODO: Explain what swizzle_mx_value and swizzle_mx_scale are if major < 9: # NYI for Ampere swizzle_mx_value = None @@ -161,7 +161,7 @@ def create_quantized_param( ): from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4, shuffle_weight + from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4 from ..integrations.tensor_parallel import shard_and_distribute_module from ..modeling_utils import _load_parameter_into_model from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts @@ -173,13 +173,10 @@ def create_quantized_param( if "gate_up_proj" in param_name: right_pad = module.gate_up_proj_right_pad bottom_pad = module.gate_up_proj_bottom_pad - # we only shuffle gate_proj - loaded_weight_shuffled = shuffle_weight(param_value).to(target_device) - loaded_weight = torch.nn.functional.pad(loaded_weight_shuffled, + loaded_weight = torch.nn.functional.pad(param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0) - del loaded_weight_shuffled torch.cuda.empty_cache() loaded_weight, flex, mx = quantize_to_mxfp4( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) From 75e0f21a2af310b43e889c739d933a23916d8cc3 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 08:50:23 +0000 Subject: [PATCH 205/342] simplify --- src/transformers/modeling_utils.py | 45 +++++++++++++++--------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1177f04d85eb..65ba8029496d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -797,7 +797,8 @@ def _load_state_dict_into_meta_model( hf_quantizer, ) - if device_mesh is not None and ( + if device_mesh is not None: + if ( not is_quantized or (not hf_quantizer.requires_parameters_quantization) or ( @@ -810,27 +811,27 @@ def _load_state_dict_into_meta_model( ) ) ): # In this case, the param is already on the correct device! - shard_and_distribute_module( - model, - param, - empty_param, - param_name, - casting_dtype, - to_contiguous, - device_mesh.get_local_rank(), - device_mesh, - ) - elif device_mesh is not None: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: - sharding_kwargs = { - "empty_param": empty_param, - "casting_dtype": casting_dtype, - "to_contiguous": to_contiguous, - "rank": device_mesh.get_local_rank(), - "device_mesh": device_mesh - } - hf_quantizer.create_quantized_param( - model, param, param_name, device_mesh.get_local_rank(), state_dict, unexpected_keys, **sharding_kwargs - ) + shard_and_distribute_module( + model, + param, + empty_param, + param_name, + casting_dtype, + to_contiguous, + device_mesh.get_local_rank(), + device_mesh, + ) + else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: + sharding_kwargs = { + "empty_param": empty_param, + "casting_dtype": casting_dtype, + "to_contiguous": to_contiguous, + "rank": device_mesh.get_local_rank(), + "device_mesh": device_mesh + } + hf_quantizer.create_quantized_param( + model, param, param_name, device_mesh.get_local_rank(), state_dict, unexpected_keys, **sharding_kwargs + ) else: param = param[...] if casting_dtype is not None: From c8ce047337644cd620178ea37a663e88589ec9c9 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 09:18:36 +0000 Subject: [PATCH 206/342] more cleaning --- src/transformers/integrations/mxfp4.py | 4 ---- src/transformers/quantizers/quantizer_mxfp4.py | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 9dc3e0d27681..2e494cf19f48 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -264,7 +264,6 @@ def _replace_with_mxfp4_linear( quantization_config=None, has_been_replaced=False, config=None, - tp_plan=None, ): if current_key_name is None: current_key_name = [] @@ -289,7 +288,6 @@ def _replace_with_mxfp4_linear( quantization_config, has_been_replaced=has_been_replaced, config=config, - tp_plan=tp_plan, ) current_key_name.pop(-1) return model, has_been_replaced @@ -301,7 +299,6 @@ def replace_with_mxfp4_linear( current_key_name=None, quantization_config=None, config=None, - tp_plan=None, ): if quantization_config.dequantize: @@ -318,7 +315,6 @@ def replace_with_mxfp4_linear( current_key_name, quantization_config, config=config, - tp_plan=tp_plan, ) if not has_been_replaced : logger.warning( diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 212be5fbce9f..7f6414e94a36 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -63,7 +63,7 @@ def validate_environment(self, *args, **kwargs): major, minor = compute_capability if not is_triton_available("3.4.0") or not is_triton_kernels_availalble(): - if self.pre_quantized: + if self.pre_quantized and not self.quantization_config.dequantize: logger.warning_once( "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" ) @@ -74,7 +74,7 @@ def validate_environment(self, *args, **kwargs): "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" ) - if major < 9: + if major < 9 and not self.quantization_config.dequantize: raise ValueError( "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)" ) @@ -138,6 +138,8 @@ def check_quantized_param( ): from ..integrations import Mxfp4OpenAIMoeExperts from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + + # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): module, tensor_name = get_module_from_name(model, param_name[:-len("_blocks")]) else: @@ -338,7 +340,6 @@ def _process_model_before_weight_loading( ): from ..integrations import replace_with_mxfp4_linear - tp_plan = model._tp_plan self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) @@ -349,7 +350,6 @@ def _process_model_before_weight_loading( modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, config=config, - tp_plan=tp_plan, ) model.config.quantization_config = self.quantization_config From 8b162f70b47f7e22158b8a8dde6fb8f6b01fc0d3 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 09:28:47 +0000 Subject: [PATCH 207/342] rm duplicated function --- .../convert_openai_weights_to_hf.py | 110 +----------------- 1 file changed, 2 insertions(+), 108 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index ed57224ceb8b..6149313a1aa9 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -16,7 +16,6 @@ import gc import json import os -import math from pathlib import Path from typing import List, Optional @@ -80,56 +79,6 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] -def convert_moe_packed_tensors( - blocks, - scales, - *, - dtype: torch.dtype = torch.bfloat16, - rows_per_chunk: int = 32768 * 1024, -) -> torch.Tensor: - import math - scales = scales.to(torch.int32) - 127 - - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape=} does not match {scales.shape=}" - ) - - lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) - - *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G - - blocks = blocks.reshape(rows_total, B) - scales = scales.reshape(rows_total, 1) - - out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) - - for r0 in range(0, rows_total, rows_per_chunk): - r1 = min(r0 + rows_per_chunk, rows_total) - - blk = blocks[r0:r1] - exp = scales[r0:r1] - - # nibble indices -> int64 - idx_lo = (blk & 0x0F).to(torch.long) - idx_hi = (blk >> 4).to(torch.long) - - sub = out[r0:r1] - sub[:, 0::2] = lut[idx_lo] - sub[:, 1::2] = lut[idx_hi] - - torch.ldexp(sub, exp, out=sub) - del idx_lo, idx_hi, blk, exp - - out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) - # to match for now existing implementation - return out.to(torch.float8_e5m2) - -FP4_VALUES = [ - +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, -] - def convert_moe_packed_tensors( blocks, scales, @@ -203,11 +152,10 @@ def write_model( print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} for file in list(os.listdir(input_base_path)): - # TODO: remove that if file.endswith(".safetensors"): final_.update(safe_load(os.path.join(input_base_path, file))) - print("Converting ..", unpack) + print("Converting ..") all_keys = final_.keys() new_keys = convert_old_keys_to_new_keys(all_keys) @@ -260,9 +208,6 @@ def write_model( if not re.search("norm", new_key): weight = weight.to(torch.bfloat16) # norms are the only ones in float32 state_dict[new_key] = weight - if "bias" in new_key and "down_proj" in new_key: - print("new_key: ", new_key) - print(f"bias.shape: {state_dict[new_key].shape}") del final_ gc.collect() @@ -379,57 +324,6 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) -FP4_VALUES = [ - +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, -] - -def convert_moe_packed_tensors( - blocks, - scales, - *, - dtype: torch.dtype = torch.bfloat16, - rows_per_chunk: int = 32768 * 1024, -) -> torch.Tensor: - - scales = scales.to(torch.int32) - 127 - - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape=} does not match {scales.shape=}" - ) - - lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) - - *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G - - blocks = blocks.reshape(rows_total, B) - scales = scales.reshape(rows_total, 1) - - out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) - - for r0 in range(0, rows_total, rows_per_chunk): - r1 = min(r0 + rows_per_chunk, rows_total) - - blk = blocks[r0:r1] - exp = scales[r0:r1] - - # nibble indices -> int64 - idx_lo = (blk & 0x0F).to(torch.long) - idx_hi = (blk >> 4).to(torch.long) - - sub = out[r0:r1] - sub[:, 0::2] = lut[idx_lo] - sub[:, 1::2] = lut[idx_hi] - - torch.ldexp(sub, exp, out=sub) - del idx_lo, idx_hi, blk, exp - - out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) - # to match for now existing implementation - out = out.to(torch.float8_e5m2) - return out.to(torch.float8_e5m2) - class OpenAIMoeConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): @@ -649,4 +543,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 8a00f600cec788504c3e375b7ac6d64b2e5e98ce Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 09:34:26 +0000 Subject: [PATCH 208/342] changing tp_plan --- .../quantizers/quantizer_mxfp4.py | 39 +++++++------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 7f6414e94a36..20dacb4c0614 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -370,40 +370,27 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): - config.base_model_tp_plan = { - # "embed_tokens": "vocab_parallel_rowwise", - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", + # Only update the existing plans, do not create new dicts + if not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None: + config.base_model_tp_plan = {} + if not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None: + config.base_model_ep_plan = {} + + # Update TP plan with scales and blocks + config.base_model_tp_plan.update({ "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj": "local_colwise", "layers.*.mlp.experts.down_proj_blocks": "local_colwise", "layers.*.mlp.experts.down_proj_scales": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local_colwise", - # "layers.*.mlp": "gather", - } - config.base_model_ep_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts": "gather", - "layers.*.mlp.router": "ep_router", - "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + }) + + # Update EP plan with scales and blocks + config.base_model_ep_plan.update({ "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", - "layers.*.mlp.experts.down_proj": "grouped_gemm", "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", - "layers.*.mlp.experts.down_proj_bias": "grouped_gemm", - } + }) return config From d760f30c0a1c63758f2f20f46855af65c1c4f27e Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 09:40:12 +0000 Subject: [PATCH 209/342] update tp plan check --- .../convert_openai_weights_to_hf.py | 2 +- .../quantizers/quantizer_mxfp4.py | 43 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 6149313a1aa9..c8daf8a43ca5 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -543,4 +543,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 20dacb4c0614..d88fc6bc51cb 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -370,31 +370,32 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): - # Only update the existing plans, do not create new dicts - if not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None: - config.base_model_tp_plan = {} - if not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None: - config.base_model_ep_plan = {} - - # Update TP plan with scales and blocks - config.base_model_tp_plan.update({ - "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj_blocks": "local_colwise", - "layers.*.mlp.experts.down_proj_scales": "local_colwise", - }) - - # Update EP plan with scales and blocks - config.base_model_ep_plan.update({ - "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", - "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", - }) + if "OpenAIMoeConfig" in config.__class__.__name__: + if not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None: + config.base_model_tp_plan = {} + if not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None: + config.base_model_ep_plan = {} + + # Update TP plan with scales and blocks + config.base_model_tp_plan.update({ + "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj_blocks": "local_colwise", + "layers.*.mlp.experts.down_proj_scales": "local_colwise", + }) + + # Update EP plan with scales and blocks + config.base_model_ep_plan.update({ + "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", + }) return config def is_serializable(self, safe_serialization=None): + logger.warning_once("MXFP4 quantization is not serializable using safetensors for now") return False @property From b34570e748a2151cbbcab14af374c3941fa9671f Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 09:57:22 +0000 Subject: [PATCH 210/342] add loading attribute --- src/transformers/quantizers/auto.py | 2 +- src/transformers/utils/quantization_config.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 9d91b8edb696..88f133ac713d 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -215,7 +215,7 @@ def merge_quantization_configs( if ( isinstance( - quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig) + quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config) ) and quantization_config_from_args is not None ): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 06e3c1b804dc..1069d02e4765 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2072,3 +2072,8 @@ def __init__( self.quant_method = QuantizationMethod.MXFP4 self.modules_to_not_convert = modules_to_not_convert self.dequantize = dequantize + + def get_loading_attributes(self): + return { + "dequantize": self.dequantize, + } \ No newline at end of file From a4950aa603889ea7d465030909a9a1cba597911c Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 10:05:58 +0000 Subject: [PATCH 211/342] dequantizing logic --- src/transformers/quantizers/base.py | 12 ++++++++---- src/transformers/quantizers/quantizer_mxfp4.py | 6 +++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 0a4ddf680461..5eacd31aeef1 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -245,10 +245,14 @@ def dequantize(self, model): model = self._dequantize(model) # Delete quantizer and quantization config - del model.hf_quantizer - del model.config.quantization_config - del model.config._pre_quantization_dtype - del model.quantization_method + if hasattr(model, "hf_quantizer"): + del model.hf_quantizer + if hasattr(model.config, "quantization_config"): + del model.config.quantization_config + if hasattr(model.config, "_pre_quantization_dtype"): + del model.config._pre_quantization_dtype + if hasattr(model, "quantization_method"): + del model.quantization_method model.is_quantized = False return model diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index d88fc6bc51cb..39471e11cb17 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -313,9 +313,13 @@ def create_quantized_param( module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + def _dequantize(self, model): return model + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + # we are not really dequantizing, we are just removing everthing related to quantization here + self.dequantize(model) + def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]): # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants new_expected_keys = [] From 89b067108847adfbfbf99951a834db8423665d2a Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 11:30:56 +0000 Subject: [PATCH 212/342] use subfunctions --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/mxfp4.py | 114 ++++++++++++++ .../quantizers/quantizer_mxfp4.py | 142 +++++------------- 3 files changed, 152 insertions(+), 108 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index b530a65f0005..cbe920f9b280 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,7 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "convert_moe_packed_tensors"], + "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "convert_moe_packed_tensors", "dequantize", "dequantize_and_quantize"], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,7 +256,7 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, Mxfp4OpenAIMoeExperts + from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 2e494cf19f48..1e3159381526 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -256,6 +256,120 @@ def should_convert_module(current_key_name, patterns): return True return False +def dequantize(module,param_name, tp_mode, model, param_value, neutral_param_name, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh): + from ..integrations.tensor_parallel import shard_and_distribute_module + + if "gate_up_proj" in param_name: + if not hasattr(module, "gate_up_proj_blocks") and not hasattr(module, "gate_up_proj_scales"): + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + return + else: + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + + dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) + module.gate_up_proj = torch.nn.Parameter(dequantized_gate_up_proj, requires_grad=False) + del module.gate_up_proj_blocks + del module.gate_up_proj_scales + return + elif "down_proj" in param_name: + if not hasattr(module, "down_proj_blocks") and not hasattr(module, "down_proj_scales"): + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + return + else: + if tp_mode: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) + dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) + module.down_proj = torch.nn.Parameter(dequantized_down_proj, requires_grad=False) + del module.down_proj_blocks + del module.down_proj_scales + return + +def dequantize_and_quantize(module,param_name, tp_mode, model, param_value, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh, swizzle_mx_value, swizzle_mx_scale): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + from ..integrations.tensor_parallel import shard_and_distribute_module + from ..modeling_utils import _load_parameter_into_model + + if "gate_up_proj" in param_name: + if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": + if tp_mode: + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + else: + _load_parameter_into_model(model, param_name, param_value) + return + else: + # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param + if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): + if tp_mode: + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + else: + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError("Something went horribly wrong mate in gate_up_proj") + + dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) + dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) + + module.device_mesh = device_mesh + module.rank = rank + + right_pad = module.gate_up_proj_right_pad + bottom_pad = module.gate_up_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del dequantized_gate_up_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) + module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + + elif "down_proj" in param_name: + if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": + if tp_mode: + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + else: + _load_parameter_into_model(model, param_name, param_value) + return + else: + if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): + if tp_mode: + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + else: + _load_parameter_into_model(model, param_name, param_value) + else: + raise ValueError("Something went horribly wrong mate in down_proj") + + dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) + dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) + module.device_mesh = device_mesh + module.rank = rank + + right_pad = module.down_proj_right_pad + bottom_pad = module.down_proj_bottom_pad + loaded_weight = torch.nn.functional.pad(dequantized_down_proj, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0) + del dequantized_down_proj + torch.cuda.empty_cache() + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) + + module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) + module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) def _replace_with_mxfp4_linear( model, diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 39471e11cb17..873da0ff1bae 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -163,7 +163,7 @@ def create_quantized_param( ): from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4 + from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4, dequantize, dequantize_and_quantize from ..integrations.tensor_parallel import shard_and_distribute_module from ..modeling_utils import _load_parameter_into_model from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts @@ -198,8 +198,13 @@ def create_quantized_param( module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) # we take this path if already quantized but not in a compatible way: else: + empty_param = kwargs.get("empty_param", None) + casting_dtype = kwargs.get("casting_dtype", None) + to_contiguous = kwargs.get("to_contiguous", None) + rank = kwargs.get("rank", None) + device_mesh = kwargs.get("device_mesh", None) if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize: - # blocks and scales have the same length that's why the below line works + # blocks and scales have the same length that's this works for both module, _ = get_module_from_name(model, param_name[:-len("_blocks")]) else: module, _ = get_module_from_name(model, param_name) @@ -207,111 +212,36 @@ def create_quantized_param( tp_mode = kwargs.get("device_mesh", None) is not None if self.quantization_config.dequantize: neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] - if "gate_up_proj" in param_name: - if not hasattr(module, "gate_up_proj_blocks") and not hasattr(module, "gate_up_proj_scales"): - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) - setattr(module, param_name.rsplit(".", 1)[1], param_value) - return - else: - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) - setattr(module, param_name.rsplit(".", 1)[1], param_value) - - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) - dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - module.gate_up_proj = torch.nn.Parameter(dequantized_gate_up_proj, requires_grad=False) - del module.gate_up_proj_blocks - del module.gate_up_proj_scales - return - elif "down_proj" in param_name: - if not hasattr(module, "down_proj_blocks") and not hasattr(module, "down_proj_scales"): - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) - setattr(module, param_name.rsplit(".", 1)[1], param_value) - return - else: - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), neutral_param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh"), set_param_inside=False) - setattr(module, param_name.rsplit(".", 1)[1], param_value) - - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - module.down_proj = torch.nn.Parameter(dequantized_down_proj, requires_grad=False) - del module.down_proj_blocks - del module.down_proj_scales - return + dequantize( + module, + param_name, + tp_mode, + model, + param_value, + neutral_param_name, + target_device, + empty_param, + casting_dtype, + to_contiguous, + rank, + device_mesh + ) else: - if "gate_up_proj" in param_name: - if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) - else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param - if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) - else: - _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError("Something went horribly wrong mate in gate_up_proj") - - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) - dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - - module.device_mesh = kwargs.get("device_mesh") - module.rank = kwargs.get("rank") - - right_pad = module.gate_up_proj_right_pad - bottom_pad = module.gate_up_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_gate_up_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - - elif "down_proj" in param_name: - if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) - else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): - if tp_mode: - shard_and_distribute_module(model, param_value, kwargs.get("empty_param"), param_name, kwargs.get("casting_dtype"), kwargs.get("to_contiguous"), kwargs.get("rank"), kwargs.get("device_mesh")) - else: - _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError("Something went horribly wrong mate in down_proj") - - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - module.device_mesh = kwargs.get("device_mesh") - module.rank = kwargs.get("rank") - - right_pad = module.down_proj_right_pad - bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_down_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_down_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) - - module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + dequantize_and_quantize( + module, + param_name, + tp_mode, + model, + param_value, + target_device, + empty_param, + casting_dtype, + to_contiguous, + rank, + device_mesh, + self.swizzle_mx_value, + self.swizzle_mx_scale + ) def _dequantize(self, model): return model From 7bfdca61ba3ac5bbccd7f9fc871b0ab621af61d1 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 11:33:39 +0000 Subject: [PATCH 213/342] import cleaning --- src/transformers/quantizers/quantizer_mxfp4.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 873da0ff1bae..f19a56845f32 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -163,9 +163,7 @@ def create_quantized_param( ): from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - from ..integrations import Mxfp4OpenAIMoeExperts, convert_moe_packed_tensors, quantize_to_mxfp4, dequantize, dequantize_and_quantize - from ..integrations.tensor_parallel import shard_and_distribute_module - from ..modeling_utils import _load_parameter_into_model + from ..integrations import Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts if not self.pre_quantized: @@ -181,7 +179,11 @@ def create_quantized_param( value=0) torch.cuda.empty_cache() loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + loaded_weight, + self.swizzle_mx_value, + self.swizzle_mx_scale + ) + torch.cuda.empty_cache() module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) elif "down_proj" in param_name: @@ -191,9 +193,13 @@ def create_quantized_param( (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0).to(target_device) - # delete intermediate tensor immediate to prevent OOM + torch.cuda.empty_cache() loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale) + loaded_weight, + self.swizzle_mx_value, + self.swizzle_mx_scale + ) + torch.cuda.empty_cache() module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) # we take this path if already quantized but not in a compatible way: From 21872bd0065f008cbfd08657a5aca002a3a38ac7 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 24 Jul 2025 11:53:36 +0000 Subject: [PATCH 214/342] update_param_name --- src/transformers/modeling_utils.py | 14 ++++++++------ src/transformers/quantizers/base.py | 6 ++++++ src/transformers/quantizers/quantizer_mxfp4.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 65ba8029496d..25a034678c37 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -6173,15 +6173,17 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, # Skip if the parameter has already been accounted for (tied weights) if param_name in tied_param_names: continue + + # Exception in the case of MXFP4 quantization, we need to update the param name to the original param name + # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name + if hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and hf_quantizer.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): + param_name = hf_quantizer.update_param_name(param_name) + try: param = model.get_parameter_or_buffer(param_name) except AttributeError: - if hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and ("blocks" in param_name or "scales" in param_name): - neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] - try: - param = model.get_parameter_or_buffer(neutral_param_name) - except AttributeError: - raise AttributeError(f"Parameter {param_name} not found in model") + raise AttributeError(f"Parameter {param_name} not found in model") + # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` param_byte_count = param.numel() * param.element_size() diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 5eacd31aeef1..93fff7be3d40 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -273,6 +273,12 @@ def _dequantize(self, model): f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub." ) + def update_param_name(self, param_name: str) -> str: + """ + Override this method if you want to adjust the `param_name`. + """ + return param_name + @staticmethod def get_modules_to_not_convert( model: "PreTrainedModel", diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index f19a56845f32..9f4b94e2089f 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -217,6 +217,8 @@ def create_quantized_param( if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): tp_mode = kwargs.get("device_mesh", None) is not None if self.quantization_config.dequantize: + # neutral_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears + # so we only have the original param name neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] dequantize( module, @@ -334,6 +336,14 @@ def update_tp_plan(self, config): return config + + def update_param_name(self, param_name: str) -> str: + if "_blocks" in param_name: + return param_name.replace("_blocks", "") + elif "_scales" in param_name: + return param_name.replace("_scales", "") + return param_name + def is_serializable(self, safe_serialization=None): logger.warning_once("MXFP4 quantization is not serializable using safetensors for now") return False From b68ece87cde00d42048f9b3bbf4c5316b2940e74 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 24 Jul 2025 15:10:46 +0000 Subject: [PATCH 215/342] adds clamped swiglu --- src/transformers/models/openai_moe/modeling_openai_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index e9cc409eed8d..0da52e6be8fa 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -40,7 +40,7 @@ @use_kernel_forward_from_hub("RMSNorm") class OpenAIMoeRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-5): """ OpenAIMoeRMSNorm is equivalent to T5LayerNorm """ @@ -71,6 +71,7 @@ def __init__(self, config): self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 + self.limit = 7.0 def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: """ @@ -114,6 +115,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] From 3e106d6276392891202b99a12740b11c42211461 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Mon, 28 Jul 2025 07:08:11 +0000 Subject: [PATCH 216/342] add clamping to training path --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 0da52e6be8fa..38ffedb79f37 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -104,6 +104,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] From 1716e6d8f8f018897456a3397de343aac5087a34 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Mon, 28 Jul 2025 09:42:49 +0000 Subject: [PATCH 217/342] simplify dequant logic --- src/transformers/integrations/mxfp4.py | 21 ++++++------------- src/transformers/modeling_utils.py | 2 +- .../quantizers/quantizer_mxfp4.py | 1 - 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 1e3159381526..91df4fe6ec77 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -137,9 +137,6 @@ def __init__(self, config): self.hidden_size_pad = 0 #smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: - """ - To update with moe mxfp4 kernels, for now we just upcast the weights in torch.bfloat16 - """ from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn with torch.cuda.device(hidden_states.device): @@ -256,20 +253,17 @@ def should_convert_module(current_key_name, patterns): return True return False -def dequantize(module,param_name, tp_mode, model, param_value, neutral_param_name, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh): +def dequantize(module,param_name, model, param_value, neutral_param_name, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh): from ..integrations.tensor_parallel import shard_and_distribute_module - + if "gate_up_proj" in param_name: + if device_mesh is not None: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) if not hasattr(module, "gate_up_proj_blocks") and not hasattr(module, "gate_up_proj_scales"): - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) setattr(module, param_name.rsplit(".", 1)[1], param_value) return else: - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) setattr(module, param_name.rsplit(".", 1)[1], param_value) - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) module.gate_up_proj = torch.nn.Parameter(dequantized_gate_up_proj, requires_grad=False) @@ -277,16 +271,13 @@ def dequantize(module,param_name, tp_mode, model, param_value, neutral_param_nam del module.gate_up_proj_scales return elif "down_proj" in param_name: + if device_mesh is not None: + param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) if not hasattr(module, "down_proj_blocks") and not hasattr(module, "down_proj_scales"): - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) setattr(module, param_name.rsplit(".", 1)[1], param_value) return else: - if tp_mode: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) setattr(module, param_name.rsplit(".", 1)[1], param_value) - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) module.down_proj = torch.nn.Parameter(dequantized_down_proj, requires_grad=False) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 25a034678c37..b5e18a047f4c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -6176,7 +6176,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, # Exception in the case of MXFP4 quantization, we need to update the param name to the original param name # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name - if hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and hf_quantizer.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and hf_quantizer.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): param_name = hf_quantizer.update_param_name(param_name) try: diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 9f4b94e2089f..afa5b9c06ec8 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -223,7 +223,6 @@ def create_quantized_param( dequantize( module, param_name, - tp_mode, model, param_value, neutral_param_name, From b8b00238996d1d118204856353405d7c0e16a7ce Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 28 Jul 2025 13:51:55 +0000 Subject: [PATCH 218/342] update --- src/transformers/modeling_utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0b787b5792c6..7c480fd616b6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2255,24 +2255,6 @@ def post_init(self): # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None - if is_torch_greater_or_equal("2.5") and _torch_distributed_available and self.config.device_mesh is not None: - # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit - device_mesh = self.config.device_mesh - for name, module in self.named_modules(): - if not getattr(module, "_is_hooked", False): - from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module - - add_tensor_parallel_hooks_to_module( - model=self, - module=module, - tp_plan=self._tp_plan, - layer_name="", # TODO: make this optional? - current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), - device_mesh=device_mesh, - parameter_name=None, - ) - module._is_hooked = True - def dequantize(self): """ Potentially dequantize the model in case it has been quantized by a quantization method that support From 69761698a08b4f7b457567a5ab0717137d902668 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 28 Jul 2025 13:54:27 +0000 Subject: [PATCH 219/342] Bad merge --- src/transformers/modeling_utils.py | 34 ------------------------------ 1 file changed, 34 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 03ce2dd53626..746f8f71fa20 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2254,40 +2254,6 @@ def post_init(self): # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None - self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} - for name, module in self.named_children(): - if plan := getattr(module, "_tp_plan", None): - self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) - - if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available: - for v in self._tp_plan.values(): - if v not in ALL_PARALLEL_STYLES: - raise ValueError( - f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" - ) - - if ( - is_torch_greater_or_equal("2.5") - and _torch_distributed_available - and hasattr(self.config, "device_mesh") - and self.config.device_mesh is not None - ): - # loop over named modules and attach hooks. this is necessary when a module doesn't have parameters and thus we never hit - device_mesh = self.config.device_mesh - for name, module in self.named_modules(): - if not getattr(module, "_is_hooked", False): - from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module - - add_tensor_parallel_hooks_to_module( - model=self, - module=module, - tp_plan=self._tp_plan, - layer_name="", # TODO: make this optional? - current_module_plan=_get_parameter_tp_plan(parameter_name=name, tp_plan=self._tp_plan), - device_mesh=device_mesh, - parameter_name=None, - ) - module._is_hooked = True def dequantize(self): """ From 195cca63d2419bb995355e10a69b8d21fe7d63b2 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Mon, 28 Jul 2025 14:12:18 +0000 Subject: [PATCH 220/342] more simplifications & tests --- src/transformers/integrations/mxfp4.py | 165 +++---- .../integrations/tensor_parallel.py | 4 +- src/transformers/modeling_utils.py | 4 +- src/transformers/quantizers/base.py | 22 +- .../quantizers/quantizer_mxfp4.py | 62 ++- src/transformers/testing_utils.py | 20 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 2 +- tests/quantization/mxfp4/__init__.py | 0 tests/quantization/mxfp4/test_mxfp4.py | 428 ++++++++++++++++++ 10 files changed, 569 insertions(+), 139 deletions(-) create mode 100644 tests/quantization/mxfp4/__init__.py create mode 100644 tests/quantization/mxfp4/test_mxfp4.py diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 91df4fe6ec77..cdb5d6b34d8f 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -253,114 +253,93 @@ def should_convert_module(current_key_name, patterns): return True return False -def dequantize(module,param_name, model, param_value, neutral_param_name, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh): +def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): from ..integrations.tensor_parallel import shard_and_distribute_module - - if "gate_up_proj" in param_name: - if device_mesh is not None: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) - if not hasattr(module, "gate_up_proj_blocks") and not hasattr(module, "gate_up_proj_scales"): - setattr(module, param_name.rsplit(".", 1)[1], param_value) - return - else: - setattr(module, param_name.rsplit(".", 1)[1], param_value) - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) - dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - module.gate_up_proj = torch.nn.Parameter(dequantized_gate_up_proj, requires_grad=False) - del module.gate_up_proj_blocks - del module.gate_up_proj_scales - return - elif "down_proj" in param_name: - if device_mesh is not None: - param_value = shard_and_distribute_module(model, param_value, empty_param, neutral_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param_inside=False) - if not hasattr(module, "down_proj_blocks") and not hasattr(module, "down_proj_scales"): - setattr(module, param_name.rsplit(".", 1)[1], param_value) - return - else: - setattr(module, param_name.rsplit(".", 1)[1], param_value) - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - module.down_proj = torch.nn.Parameter(dequantized_down_proj, requires_grad=False) - del module.down_proj_blocks - del module.down_proj_scales - return -def dequantize_and_quantize(module,param_name, tp_mode, model, param_value, target_device, empty_param, casting_dtype, to_contiguous, rank, device_mesh, swizzle_mx_value, swizzle_mx_scale): + model = kwargs.get("model", None) + empty_param = kwargs.get("empty_param", None) + casting_dtype = kwargs.get("casting_dtype", None) + to_contiguous = kwargs.get("to_contiguous", None) + rank = kwargs.get("rank", None) + device_mesh = kwargs.get("device_mesh", None) + + for proj in ["gate_up_proj", "down_proj"]: + if proj in param_name: + if device_mesh is not None: + param_value = shard_and_distribute_module( + model, param_value, empty_param, dq_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param=False + ) + blocks_attr = f"{proj}_blocks" + scales_attr = f"{proj}_scales" + if not hasattr(module, blocks_attr) and not hasattr(module, scales_attr): + setattr(module, param_name.rsplit(".", 1)[1], param_value) + return + else: + setattr(module, param_name.rsplit(".", 1)[1], param_value) + dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) + dequantized = dequantized.transpose(1, 2).to(target_device) + setattr(module, proj, torch.nn.Parameter(dequantized)) + delattr(module, blocks_attr) + delattr(module, scales_attr) + return + +def dequantize_and_quantize(module, param_name, param_value, target_device, swizzle_mx_value, swizzle_mx_scale, **kwargs): from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from ..integrations.tensor_parallel import shard_and_distribute_module from ..modeling_utils import _load_parameter_into_model - if "gate_up_proj" in param_name: - if module.gate_up_proj_blocks.device.type == "meta" and module.gate_up_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) - else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - # In this case the weights or the scales are already on the correct device, so param_value should be the other missing param - if (module.gate_up_proj_blocks.device != "meta" and "scales" in param_name) or (module.gate_up_proj_scales.device != "meta" and "blocks" in param_name): - if tp_mode: + model = kwargs.get("model", None) + empty_param = kwargs.get("empty_param", None) + casting_dtype = kwargs.get("casting_dtype", None) + to_contiguous = kwargs.get("to_contiguous", None) + rank = kwargs.get("rank", None) + device_mesh = kwargs.get("device_mesh", None) + # Combine logic for gate_up_proj and down_proj + for proj in ["gate_up_proj", "down_proj"]: + if proj in param_name: + blocks_attr = f"{proj}_blocks" + scales_attr = f"{proj}_scales" + right_pad_attr = f"{proj}_right_pad" + bottom_pad_attr = f"{proj}_bottom_pad" + precision_config_attr = f"{proj}_precision_config" + + # Check if both blocks and scales are still on meta device + blocks = getattr(module, blocks_attr) + scales = getattr(module, scales_attr) + if blocks.device.type == "meta" and scales.device.type == "meta": + if device_mesh is not None: shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) else: _load_parameter_into_model(model, param_name, param_value) + return else: - raise ValueError("Something went horribly wrong mate in gate_up_proj") - - dequantized_gate_up_proj = convert_moe_packed_tensors(module.gate_up_proj_blocks, module.gate_up_proj_scales) - dequantized_gate_up_proj = dequantized_gate_up_proj.transpose(1,2).to(target_device) - - module.device_mesh = device_mesh - module.rank = rank - - right_pad = module.gate_up_proj_right_pad - bottom_pad = module.gate_up_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_gate_up_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_gate_up_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) - module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - - elif "down_proj" in param_name: - if module.down_proj_blocks.device.type == "meta" and module.down_proj_scales.device.type == "meta": - if tp_mode: - shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) - else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - if (module.down_proj_blocks.device != "meta" and "scales" in param_name) or (module.down_proj_scales.device != "meta" and "blocks" in param_name): - if tp_mode: + # One of the params is already loaded, so load the other + if device_mesh is not None: shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) else: _load_parameter_into_model(model, param_name, param_value) - else: - raise ValueError("Something went horribly wrong mate in down_proj") - dequantized_down_proj = convert_moe_packed_tensors(module.down_proj_blocks, module.down_proj_scales) - dequantized_down_proj = dequantized_down_proj.transpose(1,2).to(target_device) - module.device_mesh = device_mesh - module.rank = rank - - right_pad = module.down_proj_right_pad - bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(dequantized_down_proj, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) - del dequantized_down_proj - torch.cuda.empty_cache() - with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) - - module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) + dequantized = dequantized.transpose(1, 2).to(target_device) + + module.device_mesh = device_mesh + module.rank = rank + + right_pad = getattr(module, right_pad_attr) + bottom_pad = getattr(module, bottom_pad_attr) + loaded_weight = torch.nn.functional.pad( + dequantized, + (0, right_pad, 0, bottom_pad, 0, 0), + mode="constant", + value=0 + ) + del dequantized + with torch.cuda.device(target_device): + loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) + setattr(module, precision_config_attr, PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex))) + setattr(module, proj, torch.nn.Parameter(loaded_weight, requires_grad=False)) + return def _replace_with_mxfp4_linear( model, diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 96af3c16e0de..33308a8ef3f7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1068,7 +1068,7 @@ def __init__(self): def shard_and_distribute_module( - model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param_inside=True + model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param=True ): # TODO: rename to shard_and_distribute_param r""" Main uses cases: @@ -1118,7 +1118,7 @@ def shard_and_distribute_module( # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) - if set_param_inside: + if set_param: setattr(module_to_tp, param_type, param) # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) return param diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b5e18a047f4c..87adc656215f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -6174,9 +6174,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, if param_name in tied_param_names: continue - # Exception in the case of MXFP4 quantization, we need to update the param name to the original param name + # For example in the case of MXFP4 quantization, we need to update the param name to the original param name # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name - if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {QuantizationMethod.MXFP4} and hf_quantizer.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): + if hf_quantizer is not None: param_name = hf_quantizer.update_param_name(param_name) try: diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 93fff7be3d40..ba0034498a94 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -237,14 +237,10 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs): """ return self._process_model_after_weight_loading(model, **kwargs) - def dequantize(self, model): + def remove_quantization_config(self, model): """ - Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. - Note not all quantization schemes support this. + Remove the quantization config from the model. """ - model = self._dequantize(model) - - # Delete quantizer and quantization config if hasattr(model, "hf_quantizer"): del model.hf_quantizer if hasattr(model.config, "quantization_config"): @@ -255,6 +251,20 @@ def dequantize(self, model): del model.quantization_method model.is_quantized = False + def dequantize(self, model): + """ + Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. + Note not all quantization schemes support this. + """ + model = self._dequantize(model) + + # Delete quantizer and quantization config + del model.hf_quantizer + del model.config.quantization_config + del model.config._pre_quantization_dtype + del model.quantization_method + model.is_quantized = False + return model def get_cuda_warm_up_factor(self): diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index afa5b9c06ec8..2d93201c82ad 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -177,13 +177,11 @@ def create_quantized_param( (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0) - torch.cuda.empty_cache() loaded_weight, flex, mx = quantize_to_mxfp4( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale ) - torch.cuda.empty_cache() module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) elif "down_proj" in param_name: @@ -193,16 +191,15 @@ def create_quantized_param( (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0).to(target_device) - torch.cuda.empty_cache() loaded_weight, flex, mx = quantize_to_mxfp4( loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale ) - torch.cuda.empty_cache() module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) - # we take this path if already quantized but not in a compatible way: + # we take this path if already quantized but not in a compatible way + # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales else: empty_param = kwargs.get("empty_param", None) casting_dtype = kwargs.get("casting_dtype", None) @@ -214,48 +211,44 @@ def create_quantized_param( module, _ = get_module_from_name(model, param_name[:-len("_blocks")]) else: module, _ = get_module_from_name(model, param_name) + + shard_kwargs = { + "empty_param": empty_param, + "casting_dtype": casting_dtype, + "to_contiguous": to_contiguous, + "rank": rank, + "device_mesh": device_mesh, + "model": model, + } + if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): - tp_mode = kwargs.get("device_mesh", None) is not None if self.quantization_config.dequantize: - # neutral_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears + # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears # so we only have the original param name - neutral_param_name = param_name[:-len("_blocks")] if "blocks" in param_name else param_name[:-len("_scales")] + dq_param_name = param_name[:-len("_blocks")] dequantize( module, param_name, - model, param_value, - neutral_param_name, target_device, - empty_param, - casting_dtype, - to_contiguous, - rank, - device_mesh + dq_param_name, + **shard_kwargs ) else: dequantize_and_quantize( module, param_name, - tp_mode, - model, param_value, target_device, - empty_param, - casting_dtype, - to_contiguous, - rank, - device_mesh, self.swizzle_mx_value, - self.swizzle_mx_scale + self.swizzle_mx_scale, + **shard_kwargs ) - def _dequantize(self, model): - return model - def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): # we are not really dequantizing, we are just removing everthing related to quantization here - self.dequantize(model) + if self.quantization_config.dequantize: + self.remove_quantization_config(model) def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]): # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants @@ -312,10 +305,8 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li def update_tp_plan(self, config): if "OpenAIMoeConfig" in config.__class__.__name__: - if not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None: - config.base_model_tp_plan = {} - if not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None: - config.base_model_ep_plan = {} + if (not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None) or (not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None): + return config # Update TP plan with scales and blocks config.base_model_tp_plan.update({ @@ -337,10 +328,11 @@ def update_tp_plan(self, config): def update_param_name(self, param_name: str) -> str: - if "_blocks" in param_name: - return param_name.replace("_blocks", "") - elif "_scales" in param_name: - return param_name.replace("_scales", "") + if self.quantization_config.dequantize: + if "_blocks" in param_name: + return param_name.replace("_blocks", "") + elif "_scales" in param_name: + return param_name.replace("_scales", "") return param_name def is_serializable(self, safe_serialization=None): diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index fd5f62ec28c0..181eff65cb02 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -66,6 +66,7 @@ from .utils import ( ACCELERATE_MIN_VERSION, GGUF_MIN_VERSION, + TRITON_MIN_VERSION, is_accelerate_available, is_apex_available, is_apollo_torch_available, @@ -168,6 +169,8 @@ is_torchcodec_available, is_torchdynamo_available, is_torchvision_available, + is_triton_available, + is_triton_kernels_availalble, is_vision_available, is_vptq_available, strtobool, @@ -454,6 +457,23 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}" )(test_case) +def require_triton(min_version: str = TRITON_MIN_VERSION): + """ + Decorator marking a test that requires triton. These tests are skipped when triton isn't installed. + """ + def decorator(test_case): + return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")( + test_case + ) + return decorator + +def require_triton_kernels(test_case): + """ + Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed. + """ + return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")( + test_case + ) def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): """ diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 4b96244f75b6..cf8bfd55fd34 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -114,6 +114,7 @@ from .import_utils import ( ACCELERATE_MIN_VERSION, ENV_VARS_TRUE_AND_AUTO_VALUES, + TRITON_MIN_VERSION, ENV_VARS_TRUE_VALUES, GGUF_MIN_VERSION, TORCH_FX_REQUIRED_VERSION, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index efd6c45e710f..cb9f4e173563 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -111,7 +111,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ VPTQ_MIN_VERSION = "0.0.4" TORCHAO_MIN_VERSION = "0.4.0" AUTOROUND_MIN_VERSION = "0.5.0" -TRITON_MIN_VERSION = "3.4.0" +TRITON_MIN_VERSION = "1.0.0" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") diff --git a/tests/quantization/mxfp4/__init__.py b/tests/quantization/mxfp4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py new file mode 100644 index 000000000000..0e1e636a333d --- /dev/null +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -0,0 +1,428 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest +from unittest.mock import patch + +import pytest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Mxfp4Config, OpenAIMoeForCausalLM +from transformers.testing_utils import ( + require_accelerate, + require_torch_large_gpu, + require_triton, + require_triton_kernels, + require_torch, + require_torch_gpu, + slow, + torch_device, +) +from transformers.utils import ( + is_triton_available, + is_triton_kernels_availalble, + is_torch_available, + is_accelerate_available, +) + + +if is_torch_available(): + import torch + + +class Mxfp4ConfigTest(unittest.TestCase): + def test_basic_config_creation(self): + """Test basic configuration creation with default values""" + config = Mxfp4Config() + self.assertEqual(config.quant_method.value, "mxfp4") + self.assertIsNone(config.modules_to_not_convert) + self.assertFalse(config.dequantize) + + def test_config_with_modules_to_not_convert(self): + """Test configuration with modules to not convert""" + modules = ["model.layers.*.self_attn", "lm_head"] + config = Mxfp4Config(modules_to_not_convert=modules) + self.assertEqual(config.modules_to_not_convert, modules) + + def test_config_with_dequantize(self): + """Test configuration with dequantize enabled""" + config = Mxfp4Config(dequantize=True) + self.assertTrue(config.dequantize) + + def test_get_loading_attributes(self): + """Test get_loading_attributes method""" + config = Mxfp4Config(dequantize=True) + attrs = config.get_loading_attributes() + self.assertEqual(attrs, {"dequantize": True}) + + def test_to_dict(self): + """Test configuration serialization to dict""" + config = Mxfp4Config(modules_to_not_convert=["lm_head"], dequantize=True) + config_dict = config.to_dict() + self.assertEqual(config_dict["quant_method"], "mxfp4") + self.assertEqual(config_dict["modules_to_not_convert"], ["lm_head"]) + self.assertTrue(config_dict["dequantize"]) + + def test_from_dict(self): + """Test configuration creation from dict""" + config_dict = {"quant_method": "mxfp4", "modules_to_not_convert": ["lm_head"], "dequantize": True} + config = Mxfp4Config.from_dict(config_dict) + self.assertEqual(config.modules_to_not_convert, ["lm_head"]) + self.assertTrue(config.dequantize) + + +class Mxfp4QuantizerTest(unittest.TestCase): + """Test the Mxfp4HfQuantizer class""" + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def test_quantizer_validation_no_torch(self): + """Test quantizer validation when torch is not available""" + with patch("transformers.quantizers.quantizer_mxfp4.is_torch_available", return_value=False): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + with self.assertRaises(ImportError): + quantizer.validate_environment() + + def test_quantizer_validation_no_cuda(self): + """Test quantizer validation when CUDA is not available""" + with patch("torch.cuda.is_available", return_value=False): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + with self.assertRaises(RuntimeError): + quantizer.validate_environment() + + def test_quantizer_validation_low_compute_capability(self): + """Test quantizer validation with low compute capability""" + with patch("torch.cuda.get_device_capability", return_value=(8, 0)): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + with self.assertRaises(ValueError): + quantizer.validate_environment() + + def test_quantizer_validation_low_compute_capability_with_dequantize(self): + """Test quantizer validation with low compute capability but dequantize enabled""" + with patch("torch.cuda.get_device_capability", return_value=(8, 0)): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config(dequantize=True) + quantizer = Mxfp4HfQuantizer(config) + + # Should not raise error with dequantize=True + try: + quantizer.validate_environment() + except ValueError as e: + if "compute capability" in str(e): + self.fail("Should not raise compute capability error when dequantize=True") + + def test_quantizer_validation_missing_triton(self): + """Test quantizer validation when triton is not available""" + with ( + patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), + patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), + ): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + quantizer.pre_quantized = False + with self.assertRaises(ValueError): + quantizer.validate_environment() + + def test_quantizer_validation_missing_triton_pre_quantized_no_dequantize(self): + """Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False""" + with ( + patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False), + patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False), + ): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + quantizer.pre_quantized = True + + # Should automatically set dequantize=True and warn + quantizer.validate_environment() + self.assertTrue(quantizer.quantization_config.dequantize) + + def test_update_torch_dtype(self): + """Test torch dtype updating""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + # Should default to bfloat16 + result_dtype = quantizer.update_torch_dtype(None) + self.assertEqual(result_dtype, torch.bfloat16) + + # Should preserve existing dtype + result_dtype = quantizer.update_torch_dtype(torch.float32) + self.assertEqual(result_dtype, torch.float32) + + def test_update_expected_keys(self): + """Test expected keys updating for quantized models""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + expected_keys = [ + "model.layers.0.mlp.experts.gate_up_proj", + "model.layers.0.mlp.experts.down_proj", + "model.embed_tokens.weight", + ] + + updated_keys = quantizer.update_expected_keys(None, expected_keys, []) + + expected_updated = [ + "model.layers.0.mlp.experts.gate_up_proj_blocks", + "model.layers.0.mlp.experts.gate_up_proj_scales", + "model.layers.0.mlp.experts.down_proj_blocks", + "model.layers.0.mlp.experts.down_proj_scales", + "model.embed_tokens.weight", + ] + + self.assertEqual(set(updated_keys), set(expected_updated)) + + def test_update_param_name_dequantize(self): + """Test parameter name updating when dequantizing""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config(dequantize=True) + quantizer = Mxfp4HfQuantizer(config) + + # Should remove _blocks suffix + param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks" + updated_name = quantizer.update_param_name(param_name) + self.assertEqual(updated_name, "model.layers.0.mlp.experts.gate_up_proj") + + # Should remove _scales suffix + param_name = "model.layers.0.mlp.experts.down_proj_scales" + updated_name = quantizer.update_param_name(param_name) + self.assertEqual(updated_name, "model.layers.0.mlp.experts.down_proj") + + # Should not change other names + param_name = "model.embed_tokens.weight" + updated_name = quantizer.update_param_name(param_name) + self.assertEqual(updated_name, "model.embed_tokens.weight") + + def test_update_param_name_no_dequantize(self): + """Test parameter name updating when not dequantizing""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config(dequantize=False) + quantizer = Mxfp4HfQuantizer(config) + + param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks" + updated_name = quantizer.update_param_name(param_name) + self.assertEqual(updated_name, param_name) + + def test_is_serializable(self): + """Test serialization capability""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + # MXFP4 is not serializable with safetensors + self.assertFalse(quantizer.is_serializable()) + + def test_is_trainable(self): + """Test trainability""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + # MXFP4 is not trainable + self.assertFalse(quantizer.is_trainable) + + +class Mxfp4IntegrationTest(unittest.TestCase): + """Test mxfp4 integration functions""" + + def test_should_convert_module(self): + """Test module conversion decision logic""" + from transformers.integrations.mxfp4 import should_convert_module + + # Should convert by default + self.assertTrue(should_convert_module(["model", "layers", "0", "mlp"], [])) + + # Should not convert if in exclusion list + patterns = ["model.layers.*.self_attn", "lm_head"] + self.assertFalse(should_convert_module(["model", "layers", "0", "self_attn"], patterns)) + self.assertFalse(should_convert_module(["lm_head"], patterns)) + + # Should convert if not in exclusion list + self.assertTrue(should_convert_module(["model", "layers", "0", "mlp", "experts"], patterns)) + + @require_torch + def test_convert_moe_packed_tensors(self): + """Test unpacking of quantized tensors""" + from transformers.integrations.mxfp4 import convert_moe_packed_tensors + + # Create dummy packed tensors + blocks = torch.randint(0, 255, (2, 4, 8), dtype=torch.uint8) + scales = torch.randint(100, 150, (2, 4), dtype=torch.uint8) + + result = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16) + + # Check output shape - should be [2, 4, 16] (8 * 2 for unpacking) + self.assertEqual(result.shape, (2, 4 * 16)) + self.assertEqual(result.dtype, torch.bfloat16) + + @require_triton(min_version="3.4.0") + @require_triton_kernels + @require_torch_gpu + @require_torch + def test_quantize_to_mxfp4(self): + """Test quantization function""" + from transformers.integrations.mxfp4 import quantize_to_mxfp4 + + # Create dummy weight tensor + w = torch.randn(32, 64, 128, dtype=torch.bfloat16, device=torch.device("cuda")) + + quantized_w, flex_data, mx_ctx = quantize_to_mxfp4(w, None, None) + + # Check that shapes are reasonable + self.assertEqual(quantized_w.dtype, torch.uint8) + self.assertIsNotNone(flex_data) + self.assertIsNotNone(mx_ctx) + + +# @require_torch +# @require_torch_large_gpu +# @slow +class Mxfp4ModelTest(unittest.TestCase): + """Test mxfp4 with actual models (requires specific model and hardware)""" + + # These should be paths to real OpenAI MoE models for proper testing + model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed" # TODO: Use real packed quantized model + + input_text = "Once upon a time" + + # Expected outputs for generation tests + EXPECTED_OUTPUTS = set() + EXPECTED_OUTPUTS.add("Once upon a time, in a small village, there lived a young") + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def check_inference_correctness_quantized(self, model, tokenizer): + # Check that inference pass works on the model + encoded_input = tokenizer(self.input_text, return_tensors="pt").to(model.device) + + # Set pad token if not set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + with torch.no_grad(): + output_sequences = model.generate( + **encoded_input, + max_new_tokens=10, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=True, + ) + + generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True) + + self.assertIn(generated_text, self.EXPECTED_OUTPUTS) + + def test_openai_moe_model_loading_quantized_with_device_map(self): + """Test loading OpenAI MoE model with mxfp4 quantization and device_map""" + + quantization_config = Mxfp4Config(dequantize=False) + + # Test that config is properly set up + self.assertFalse(quantization_config.dequantize) + + model = OpenAIMoeForCausalLM.from_pretrained( + self.model_name_packed, + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) + self.check_inference_correctness_quantized(model, tokenizer) + + def test_openai_moe_model_loading_dequantized_with_device_map(self): + """Test loading OpenAI MoE model with mxfp4 dequantization and device_map""" + + quantization_config = Mxfp4Config(dequantize=True) + + # Test that config is properly set up + self.assertTrue(quantization_config.dequantize) + + model = OpenAIMoeForCausalLM.from_pretrained( + self.model_name_packed, + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) + self.check_inference_correctness_quantized(model, tokenizer) + + def test_model_device_map_validation(self): + """Test device map validation""" + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + quantizer.pre_quantized = False + + # Test with CPU in device map (should raise error for non-pre-quantized) + with self.assertRaises(ValueError): + quantizer.validate_environment(device_map={"": "cpu"}) + + def test_memory_footprint_comparison(self): + """Test memory footprint differences between quantized and unquantized models""" + + # Expected: quantized < dequantized < unquantized memory usage + quantization_config = Mxfp4Config(dequantize=True) + quantized_model = OpenAIMoeForCausalLM.from_pretrained( + self.model_name_packed, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + dequantized_model = OpenAIMoeForCausalLM.from_pretrained( + self.model_name_packed, + torch_dtype=torch.bfloat16, + device_map="auto", + quantization_config=quantization_config, + ) + quantized_mem = quantized_model.get_memory_footprint() + dequantized_mem = dequantized_model.get_memory_footprint() + self.assertLess(quantized_mem, dequantized_mem) \ No newline at end of file From 345afb135efac29562ca2fc9d4e52c0a4b3d4a48 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 28 Jul 2025 14:16:34 +0000 Subject: [PATCH 221/342] fix ! --- src/transformers/integrations/tensor_parallel.py | 9 +++++---- src/transformers/modeling_utils.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 6e3bb6dc8b3b..9e0fd8d3895c 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -525,12 +525,12 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) return param -class VocabParallel(TensorParallelLayer): +class VocabParallel(TensorParallelLayer): """ VocabParallel is a tensor parallel layer that shards the embedding table along the last dimension. No need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) - + This is useful if you want to train with long sequence length! """ @@ -582,6 +582,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=False) return outputs.to_local() if use_local_output else outputs + class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. @@ -1143,12 +1144,12 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): model._tp_size = tp_size model._device_mesh = device_mesh if distributed_config is not None: - distributed_config = DistributedConfig.from_config(distributed_config) + if isinstance(DistributedConfig, dict): + distributed_config = DistributedConfig.from_dict(distributed_config) if distributed_config.enable_expert_parallel: _plan = "_ep_plan" model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy() - # now fetch my childrens for name, module in model.named_children(): if plan := getattr(module, _plan, getattr(module, "tp_plan", None)): model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 746f8f71fa20..88fbed869d82 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4561,7 +4561,6 @@ def from_pretrained( load_in_8bit = kwargs.pop("load_in_8bit", False) load_in_4bit = kwargs.pop("load_in_4bit", False) quantization_config = kwargs.pop("quantization_config", None) - distributed_config = kwargs.pop("distributed_config", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) From 009355a6d1873ae64c9baa7636b10ce9eef3da76 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 28 Jul 2025 15:58:31 +0000 Subject: [PATCH 222/342] fix registering custom attention --- src/transformers/modeling_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 88fbed869d82..bad89184c60e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2699,10 +2699,10 @@ def _check_and_adjust_attn_implementation( kernel_function = partial(flash_attention_forward, implementation=kernel) elif kernel_name is not None: kernel_function = getattr(kernel, kernel_name) - # Register it - ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) - applicable_attn_implementation = repo_id + ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function) + ALL_MASK_ATTENTION_FUNCTIONS.register( + applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) except Exception as e: logger.warning_once( f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " From d237a90c0de1d6595362e35afc25dc917b4b9b83 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 08:48:10 +0000 Subject: [PATCH 223/342] fix order --- src/transformers/integrations/mxfp4.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index cdb5d6b34d8f..e9d620699d7f 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -168,19 +168,6 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter return intermediate_cache3 -def mlp_forward(self, hidden_states): - import torch.distributed as dist - if dist.is_available() and dist.is_initialized(): - routing = routing_torch_dist - else: - from triton_kernels.routing import routing - routing = routing - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) - router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) - routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) - return routed_out, router_logits - def routing_torch_dist( logits, n_expts_act, @@ -245,6 +232,19 @@ def topk(vals, k, expt_indx): hitted_experts = n_expts_act return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx +def mlp_forward(self, hidden_states): + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + routing = routing_torch_dist + else: + from triton_kernels.routing import routing + routing = routing + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + return routed_out, router_logits + def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) if not any( From ccffc0b90bbf8c5d543c73978c97e9743ab5d937 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 08:53:16 +0000 Subject: [PATCH 224/342] fixes --- src/transformers/integrations/mxfp4.py | 65 +++++++++++++++++++------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index e9d620699d7f..4323e23d2ac6 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -45,12 +45,15 @@ def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): axis=1, swizzle_axis=swizzle_axis, swizzle_scale=swizzle_mx_scale, - swizzle_value=swizzle_mx_value) + swizzle_value=swizzle_mx_value + ) + return w, InFlexData(), MicroscalingCtx( weight_scale=mx_scales, swizzle_scale=swizzle_mx_scale, swizzle_value=swizzle_mx_value, - actual_weight_scale_shape=weight_scale_shape) + actual_weight_scale_shape=weight_scale_shape + ) def convert_moe_packed_tensors( blocks, @@ -95,12 +98,14 @@ def convert_moe_packed_tensors( del idx_lo, idx_hi, blk, exp out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + return out class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): super().__init__() + self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size @@ -139,23 +144,28 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn + with torch.cuda.device(hidden_states.device): act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) if self.hidden_size_pad is not None: - hidden_states = torch.nn.functional.pad(hidden_states, - (0, self.hidden_size_pad, 0, 0), - mode="constant", - value=0) - - intermediate_cache1 = matmul_ogs(hidden_states, - self.gate_up_proj, - self.gate_up_proj_bias.to(torch.float32), - routing_data, - gather_indx=gather_idx, - precision_config=self.gate_up_proj_precision_config, - gammas=None, - fused_activation=act) + hidden_states = torch.nn.functional.pad( + hidden_states, + (0, self.hidden_size_pad, 0, 0), + mode="constant", + value=0 + ) + + intermediate_cache1 = matmul_ogs( + hidden_states, + self.gate_up_proj, + self.gate_up_proj_bias.to(torch.float32), + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=act + ) intermediate_cache3 = matmul_ogs( intermediate_cache1, @@ -239,6 +249,7 @@ def mlp_forward(self, hidden_states): else: from triton_kernels.routing import routing routing = routing + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) @@ -309,14 +320,32 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz scales = getattr(module, scales_attr) if blocks.device.type == "meta" and scales.device.type == "meta": if device_mesh is not None: - shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + shard_and_distribute_module( + model, + param_value, + empty_param, + param_name, + casting_dtype, + to_contiguous, + rank, + device_mesh + ) else: _load_parameter_into_model(model, param_name, param_value) return else: # One of the params is already loaded, so load the other if device_mesh is not None: - shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + shard_and_distribute_module( + model, + param_value, + empty_param, + param_name, + casting_dtype, + to_contiguous, + rank, + device_mesh + ) else: _load_parameter_into_model(model, param_name, param_value) @@ -337,8 +366,10 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz del dequantized with torch.cuda.device(target_device): loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) + setattr(module, precision_config_attr, PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex))) setattr(module, proj, torch.nn.Parameter(loaded_weight, requires_grad=False)) + return def _replace_with_mxfp4_linear( From f92878af227d0fc585801b870fcc37c5735dcb32 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 09:39:15 +0000 Subject: [PATCH 225/342] some test nits --- tests/quantization/mxfp4/test_mxfp4.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index 0e1e636a333d..0c2f33ca1d2a 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -13,28 +13,20 @@ # limitations under the License. import gc -import tempfile import unittest from unittest.mock import patch -import pytest - -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Mxfp4Config, OpenAIMoeForCausalLM +from transformers import AutoTokenizer, Mxfp4Config, OpenAIMoeForCausalLM from transformers.testing_utils import ( - require_accelerate, + require_torch, + require_torch_gpu, require_torch_large_gpu, require_triton, require_triton_kernels, - require_torch, - require_torch_gpu, slow, - torch_device, ) from transformers.utils import ( - is_triton_available, - is_triton_kernels_availalble, is_torch_available, - is_accelerate_available, ) @@ -315,9 +307,9 @@ def test_quantize_to_mxfp4(self): self.assertIsNotNone(mx_ctx) -# @require_torch -# @require_torch_large_gpu -# @slow +@require_torch +@require_torch_large_gpu +@slow class Mxfp4ModelTest(unittest.TestCase): """Test mxfp4 with actual models (requires specific model and hardware)""" From 90522c41e8d2fe878cfb406ab2508a3e3d1ec243 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 14:44:34 +0000 Subject: [PATCH 226/342] nits --- src/transformers/integrations/mxfp4.py | 9 ++++++--- src/transformers/quantizers/quantizer_mxfp4.py | 15 ++++++++++----- tests/quantization/mxfp4/test_mxfp4.py | 2 +- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 4323e23d2ac6..7492d8b417cf 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -32,6 +32,8 @@ -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] +# Copied from GPT_OSS repo +# TODO: Add absolute link when the repo is public def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx from triton_kernels.numerics_details.mxfp import downcast_to_mxfp @@ -55,6 +57,8 @@ def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): actual_weight_scale_shape=weight_scale_shape ) +# Copied from GPT_OSS repo +# TODO: Add absolute link when the repo is public def convert_moe_packed_tensors( blocks, scales, @@ -178,6 +182,8 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter return intermediate_cache3 +# Adapted from GPT_OSS repo +# TODO: Add absolute link when the repo is public def routing_torch_dist( logits, n_expts_act, @@ -352,9 +358,6 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) dequantized = dequantized.transpose(1, 2).to(target_device) - module.device_mesh = device_mesh - module.rank = rank - right_pad = getattr(module, right_pad_attr) bottom_pad = getattr(module, bottom_pad_attr) loaded_weight = torch.nn.functional.pad( diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 2d93201c82ad..019d7976e7d3 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -59,6 +59,14 @@ def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): raise RuntimeError("Using MXFP4 quantized models requires a GPU") + if not is_accelerate_available(): + raise ImportError( + "Using mxfp4 requires Accelerate: `pip install 'accelerate>=1.8.0'`" + ) + + if self.quantization_config.dequantize: + return + compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability @@ -68,20 +76,17 @@ def validate_environment(self, *args, **kwargs): "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" ) self.quantization_config.dequantize = True + return else: # we can't quantize the model in this case so we raise an error raise ValueError( "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" ) - if major < 9 and not self.quantization_config.dequantize: + if major < 9 : raise ValueError( "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)" ) - if not is_accelerate_available(): - raise ImportError( - "Using mxfp4 requires Accelerate: `pip install 'accelerate>=1.8.0'`" - ) device_map = kwargs.get("device_map", None) if device_map is None: diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index 0c2f33ca1d2a..7dead0fff025 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -346,7 +346,7 @@ def check_inference_correctness_quantized(self, model, tokenizer): max_new_tokens=10, do_sample=False, pad_token_id=tokenizer.eos_token_id, - use_cache=True, + use_cache=False, ) generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True) From dbb8b20a80b23cf69423268ec9ac874597025da7 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 14:48:34 +0000 Subject: [PATCH 227/342] nit --- src/transformers/quantizers/quantizer_mxfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 019d7976e7d3..50c9033d1524 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -61,7 +61,7 @@ def validate_environment(self, *args, **kwargs): if not is_accelerate_available(): raise ImportError( - "Using mxfp4 requires Accelerate: `pip install 'accelerate>=1.8.0'`" + "Using mxfp4 requires Accelerate: `pip install accelerate`" ) if self.quantization_config.dequantize: From edd92321acbfd4550c5210f17f8153ef44863eff Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Tue, 29 Jul 2025 17:57:18 +0000 Subject: [PATCH 228/342] fix --- src/transformers/quantizers/quantizer_mxfp4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 50c9033d1524..dc20ab3cb284 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -166,7 +166,8 @@ def create_quantized_param( unexpected_keys: Optional[list[str]] = None, **kwargs, ): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + if is_triton_kernels_availalble(): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from ..integrations import Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts From dc2b16fef22119068c854aa8c31cc1a847eb4296 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 29 Jul 2025 23:56:57 +0000 Subject: [PATCH 229/342] Clamp sink logits --- src/transformers/models/openai_moe/modeling_openai_moe.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 38ffedb79f37..2f07dd23eb3d 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -248,8 +248,12 @@ def eager_attention_forward( # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - sinks = torch.exp(sinks - logits_max) - unnormalized_scores = torch.exp(attn_weights - logits_max) + # sinks = torch.exp(sinks - logits_max) + # unnormalized_scores = torch.exp(attn_weights - logits_max) + # lewtun: Prevent overflow in BF16/FP16: exp(>80) → inf + sinks = torch.exp(torch.clamp(sinks - logits_max, max=80)) + unnormalized_scores = torch.exp(torch.clamp(attn_weights - logits_max, min=-80)) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks scores = unnormalized_scores / normalizer From b0508307ae54d69f8c6238e8828ccfe2f917440a Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 30 Jul 2025 00:14:15 +0000 Subject: [PATCH 230/342] Clean --- src/transformers/models/openai_moe/modeling_openai_moe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 2f07dd23eb3d..0043ccf44f09 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -248,9 +248,7 @@ def eager_attention_forward( # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - # sinks = torch.exp(sinks - logits_max) - # unnormalized_scores = torch.exp(attn_weights - logits_max) - # lewtun: Prevent overflow in BF16/FP16: exp(>80) → inf + # prevent overflow in BF16/FP16 when training with bsz>1. we clamp in the range [-80, 80] because exp(>80) ~ inf sinks = torch.exp(torch.clamp(sinks - logits_max, max=80)) unnormalized_scores = torch.exp(torch.clamp(attn_weights - logits_max, min=-80)) From e0e406ecb9b70532920c71ac81f8ccd05707f859 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 30 Jul 2025 01:37:49 +0000 Subject: [PATCH 231/342] Soft-max trick --- .../models/openai_moe/modeling_openai_moe.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 0043ccf44f09..9d0ec7739e84 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -247,13 +247,22 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows - logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - # prevent overflow in BF16/FP16 when training with bsz>1. we clamp in the range [-80, 80] because exp(>80) ~ inf - sinks = torch.exp(torch.clamp(sinks - logits_max, max=80)) - unnormalized_scores = torch.exp(torch.clamp(attn_weights - logits_max, min=-80)) - - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer + # logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values + # # prevent overflow in BF16/FP16 when training with bsz>1. we clamp in the range [-80, 80] because exp(>80) ~ inf + # sinks = torch.exp(torch.clamp(sinks - logits_max, max=80)) + # unnormalized_scores = torch.exp(torch.clamp(attn_weights - logits_max, min=-80)) + + # normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + # scores = unnormalized_scores / normalizer + + # Numerically‑safe soft‑max: combine sinks with K‑logits, + # subtract per‑row max, then soft‑max once. + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + # soft‑max trick + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + # drop sink prob + scores = probs[..., :-1] # TODO we are going with this one to fix gradients becoming nans # sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) From 54e8825461524296e7af420e42da4eb1d3a57cb5 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 30 Jul 2025 01:45:11 +0000 Subject: [PATCH 232/342] Clean up --- .../models/openai_moe/modeling_openai_moe.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 9d0ec7739e84..43c07d4d3317 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -247,18 +247,10 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows - # logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - # # prevent overflow in BF16/FP16 when training with bsz>1. we clamp in the range [-80, 80] because exp(>80) ~ inf - # sinks = torch.exp(torch.clamp(sinks - logits_max, max=80)) - # unnormalized_scores = torch.exp(torch.clamp(attn_weights - logits_max, min=-80)) - # normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - # scores = unnormalized_scores / normalizer - - # Numerically‑safe soft‑max: combine sinks with K‑logits, - # subtract per‑row max, then soft‑max once. + # combine sinks with K-logits to prevent overflow in BF16/FP16 when training with bsz>1 + # subtract per‑row max, then soft‑max once (soft-max trick) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # soft‑max trick combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) # drop sink prob From 0378ae863aee67815d22d954c254f42f28328f54 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 30 Jul 2025 02:03:07 +0000 Subject: [PATCH 233/342] p --- src/transformers/models/openai_moe/modeling_openai_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 43c07d4d3317..a85e64e9f166 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -248,7 +248,7 @@ def eager_attention_forward( # TODO: check wether both produce the same results or not! # scale the logits to prevent overflows - # combine sinks with K-logits to prevent overflow in BF16/FP16 when training with bsz>1 + # combine sinks with attention weight to prevent overflow in BF16/FP16 when training with bsz>1 # subtract per‑row max, then soft‑max once (soft-max trick) combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values From 077cfeef4625345d82d667daa261fd52b96e4c46 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 30 Jul 2025 09:23:48 +0000 Subject: [PATCH 234/342] fix deepspeed --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd41477330c1..057b05f9aee2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -881,8 +881,11 @@ def _load_state_dict_into_meta_model( # and then cast it to CPU to avoid excessive memory usage on each GPU # in comparison to the sharded model across GPUs. if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): + param_name = hf_quantizer.update_param_name(param_name) module, param_type = get_module_from_name(model, param_name) value = getattr(module, param_type) + if value.device.type == "meta": + continue param_to = "cpu" if is_fsdp_enabled() and not is_local_dist_rank_0(): param_to = "meta" From bec11b7970dbedf71b9e59a0494b7e87bac4338b Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 11:40:21 +0200 Subject: [PATCH 235/342] update both modeling and modular for cleanup --- .../models/openai_moe/modeling_openai_moe.py | 96 +++------ .../models/openai_moe/modular_openai_moe.py | 195 +++++++++--------- 2 files changed, 131 insertions(+), 160 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index a85e64e9f166..c09e3a3bf883 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -26,21 +26,20 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs, OutputRecorder +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, use_kernel_forward_from_hub +from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_openai_moe import OpenAIMoeConfig @use_kernel_forward_from_hub("RMSNorm") class OpenAIMoeRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-5): + def __init__(self, hidden_size, eps=1e-6): """ OpenAIMoeRMSNorm is equivalent to T5LayerNorm """ @@ -125,10 +124,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] next_states = next_states.sum(dim=0) - return next_states, routing_weights + return next_states -class TopKRouter(nn.Module): +class OpenAiMoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -143,7 +142,7 @@ def forward(self, hidden_states): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + return router_scores, router_indices, router_logits @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -154,18 +153,16 @@ def __init__(self, config): self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out, router_weights = self.experts( - hidden_states, router_indices=router_indices, routing_weights=router_scores - ) - return routed_out, router_scores + router_scores, router_indices, _ = self.router(hidden_states) # (num_experts, seq_len) + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out class OpenAIMoeRotaryEmbedding(nn.Module): def __init__(self, config: OpenAIMoeConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" @@ -238,28 +235,17 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - # TODO: check wether both produce the same results or not! - # scale the logits to prevent overflows - # combine sinks with attention weight to prevent overflow in BF16/FP16 when training with bsz>1 - # subtract per‑row max, then soft‑max once (soft-max trick) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) - # drop sink prob - scores = probs[..., :-1] - - # TODO we are going with this one to fix gradients becoming nans - # sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - # combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # scores = nn.functional.softmax(combined_logits, dim=-1)[..., :-1] + scores = probs[..., :-1] # we drop the sink here attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -313,8 +299,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -359,9 +344,10 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -373,29 +359,30 @@ def forward( **kwargs, ) hidden_states = residual + hidden_states + + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, _ = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @auto_docstring class OpenAIMoePreTrainedModel(PreTrainedModel): - config_class = OpenAIMoeConfig + config: OpenAIMoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["OpenAIMoeDecoderLayer", "OpenAIMoeAttention"] + _no_split_modules = ["OpenAIMoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True + _supports_flash_attn = True _supports_sdpa = False _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True + + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeMLP, index=1), + "router_logits": OutputRecorder(OpenAiMoeTopKRouter, index=2), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention, } @@ -443,12 +430,6 @@ def __init__(self, config: OpenAIMoeConfig): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @check_model_inputs @auto_docstring def forward( @@ -487,7 +468,6 @@ def forward( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "position_ids": position_ids, } causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), @@ -566,6 +546,7 @@ def load_balancing_loss_func( else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] @@ -582,8 +563,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) - .reshape(-1, routing_weights.shape[1]) + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) .to(compute_device) ) @@ -592,11 +573,7 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - rank = routing_weights.shape[1] * int(routing_weights.device.index) - # TODO not 100% this one is correct, will fix in another PR - overall_loss = torch.sum( - tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) - ) + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts @@ -618,18 +595,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def set_decoder(self, decoder): self.model = decoder @@ -643,7 +608,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -674,9 +639,12 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index a0efcb4806a3..56b773a4845c 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -25,7 +25,13 @@ ) from ...modeling_rope_utils import dynamic_rope_update from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, logging, TransformersKwargs, OutputRecorder +from ...utils import ( + auto_docstring, + logging, + TransformersKwargs, + use_kernel_forward_from_hub, +) +from ...utils.generic import check_model_inputs, OutputRecorder from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaPreTrainedModel, @@ -62,6 +68,7 @@ def __init__(self, config): self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 + self.limit = 7.0 def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: """ @@ -73,50 +80,52 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig Args: hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) selected_experts (torch.Tensor): (batch_size * token_num, top_k) - routing_weights (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) Returns: torch.Tensor """ batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts).permute( - 2, 1, 0 - ) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence lenght to get which experts + # are hit this time around expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted: + for expert_idx in expert_hitted[:]: with torch.no_grad(): - idx, top_x = torch.where( - expert_mask[expert_idx][0] - ) # idx: top-1/top-2 indicator, top_x: token indices - current_state = hidden_states[top_x] # (num_tokens, hidden_dim) - gate_up = ( - current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - ) # (num_tokens, 2 * interm_dim) + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gate, up = gate_up[..., ::2], gate_up[..., 1::2] - glu = gate * torch.sigmoid(gate * self.alpha) # (num_tokens, interm_dim) - gated_output = (up + 1) * glu # (num_tokens, interm_dim) - out = ( - gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - ) # (num_tokens, hidden_dim) - weighted_output = out * routing_weights[top_x, idx, None] # (num_tokens, hidden_dim) - next_states.index_add_(0, top_x, weighted_output.to(hidden_states.dtype)[0]) + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) else: hidden_states = hidden_states.repeat(num_experts, 1) hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) # (num_experts, batch_size, seq_len, hidden_size) - return next_states, None + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + -class TopKRouter(nn.Module): +class OpenAiMoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -124,27 +133,15 @@ def __init__(self, config): self.hidden_dim = config.hidden_size self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) self.bias = nn.Parameter(torch.empty(self.num_experts)) - + def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1) # (num_experts, seq_len) - return router_scores, router_indices - -class TokenDispatcher(nn.Module): - # this module is important to add EP hook - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size + router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices, router_logits - def forward(self, routed_out, routing_weights): - # routed_out is (num_experts, batch_size, seq_len, hidden_size) - routed_out = routed_out * routing_weights[:, None, :, None] # we're throwing away computed routed_out for rest of experts - routed_out = routed_out.sum(dim=0) # (batch_size, seq_len, hidden_size) - return routed_out @use_kernel_forward_from_hub("MegaBlocksMoeMLP") class OpenAIMoeMLP(nn.Module): @@ -152,14 +149,11 @@ def __init__(self, config): super().__init__() self.router = TopKRouter(config) self.experts = OpenAIMoeExperts(config) - self.token_dispatcher = TokenDispatcher(config) def forward(self, hidden_states): - # we don't slice weight as its not compile compatible - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out, _ = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) #TODO: router_indices isn't used inside this func - hidden_states = self.token_dispatcher(routed_out, router_scores) - return hidden_states, router_scores + router_scores, router_indices, _ = self.router(hidden_states) # (num_experts, seq_len) + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding): @@ -197,6 +191,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k_embed = _apply_rotary_emb(k, cos, sin) return q_embed, k_embed + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -209,27 +204,23 @@ def eager_attention_forward( ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - sinks = module.sinks.reshape(1, -1, 1, 1).expand( - query.shape[0], -1, query.shape[-2], -1 - ) # TODO make sure the sink is like a new token - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # scale the logits to prevent overflows - logits_max = torch.max(attn_weights, dim=-1, keepdim=True).values - sinks = torch.exp(sinks - logits_max) - unnormalized_scores = torch.exp(attn_weights - logits_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - scores = unnormalized_scores / normalizer + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) # ignore the sinks + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights + class OpenAIMoeAttention(Qwen2Attention): def __init__(self, config: OpenAIMoeConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -247,6 +238,50 @@ def __init__(self, config: OpenAIMoeConfig, layer_idx: int): ) self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama + s_aux=self.sinks, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + class OpenAIMoeDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OpenAIMoeConfig, layer_idx: int): @@ -258,47 +293,17 @@ def __init__(self, config: OpenAIMoeConfig, layer_idx: int): self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[TransformersKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, _ = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _supports_sdpa = False _supports_flex_attention = False _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeMLP, index=1), + "router_logits": OutputRecorder(OpenAiMoeTopKRouter, index=2), "hidden_states": OpenAIMoeDecoderLayer, - "attentions": OpenAIMoeAttention + "attentions": OpenAIMoeAttention, } + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): @@ -376,12 +381,10 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **flash_attn_kwargs, + **kwargs, ) hidden_states = self.norm(hidden_states) return MoeModelOutputWithPast( From 7d8ac2ed0026e6d473a07719c127bcbbc9b317c7 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 30 Jul 2025 10:25:58 +0000 Subject: [PATCH 236/342] contiguous --- src/transformers/integrations/mxfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 7492d8b417cf..c4409d7e000f 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -294,7 +294,7 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** else: setattr(module, param_name.rsplit(".", 1)[1], param_value) dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) - dequantized = dequantized.transpose(1, 2).to(target_device) + dequantized = dequantized.transpose(1, 2).contiguous().to(target_device) setattr(module, proj, torch.nn.Parameter(dequantized)) delattr(module, blocks_attr) delattr(module, scales_attr) From 42ab1088cd381ce7d0098d6acbe8fc1bb7fab8fb Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 12:43:08 +0200 Subject: [PATCH 237/342] update tests --- .../openai_moe/test_modeling_openai_moe.py | 593 ++++++------------ 1 file changed, 191 insertions(+), 402 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index c81ed898116f..014f6d6cbb69 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,453 +11,242 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch openai model.""" +"""Testing suite for the PyTorch OpenaiMoe model.""" -import os import unittest -import torch +import pytest from packaging import version +from parameterized import parameterized +from pytest import mark -from transformers import ( - AutoTokenizer, - OpenAIMoeConfig, - OpenAIMoeForCausalLM, - OpenAIMoeModel, - is_torch_available, -) +from tests.tensor_parallel.test_tensor_parallel import TestTensorParallel +from transformers import AutoModelForCausalLM, AutoTokenizer, OpenaiMoeConfig, is_torch_available, pipeline +from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( Expectations, cleanup, + is_flash_attn_2_available, + require_flash_attn, + require_large_cpu_ram, require_read_token, require_torch, require_torch_accelerator, - require_tokenizers, - require_tiktoken, + require_torch_large_accelerator, + require_torch_large_gpu, slow, torch_device, ) -from ...generation.test_utils import GenerationTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - - -class OpenAIMoeModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return OpenAIMoeConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = OpenAIMoeModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + +if is_torch_available(): + import torch + + from transformers import ( + OpenAIMoeForCausalLM, + OpenAIMoeModel, + ) + + +class OpenaiMoeModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = OpenaiMoeConfig + base_model_class = OpenAIMoeModel + causal_lm_class = OpenAIMoeForCausalLM + + pipeline_model_mapping = ( + { + "feature-extraction": OpenaiMoeModel, + "text-generation": OpenaiMoeForCausalLM, + } + if is_torch_available() + else {} + ) @require_torch -class OpenAIMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - OpenAIMoeModel, - OpenAIMoeForCausalLM, - ) +class OpenaiMoeModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (OpenaiMoeModel, OpenaiMoeForCausalLM) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": OpenaiMoeModel, + "text-generation": OpenaiMoeForCausalLM, + } if is_torch_available() - else () + else {} ) + test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez - - # Need to use `0.8` instead of `0.9` for `test_cpu_offload` - # This is because we are hitting edge cases with the causal_mask buffer - model_split_percents = [0.5, 0.7, 0.8] - - # used in `test_torch_compile_for_training` - _torch_compile_train_cls = OpenAIMoeForCausalLM if is_torch_available() else None + _is_stateful = True + model_split_percents = [0.5, 0.6] + model_tester_class = OpenaiMoeModelTester def setUp(self): - self.model_tester = OpenAIMoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenAIMoeConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() + self.model_tester = OpenaiMoeModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenaiMoeConfig, hidden_size=37) + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @unittest.skip("OpenaiMoe's forcefully disables sdpa due to Sink") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + + @unittest.skip("OpenaiMoe's eager attn/sdpa attn outputs are expected to be different") + def test_eager_matches_sdpa_generate(self): + pass + + @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate + @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @pytest.mark.generate + @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_with_static_cache(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_continue_from_inputs_embeds(self): + pass + + @unittest.skip( + reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`" + " as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting" + ) + def test_multi_gpu_data_parallel_forward(self): + pass - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip("OpenaiMoe has HybridCache which auto-compiles. Compile and FA2 don't work together.") + def test_eager_matches_fa2_generate(self): + pass - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip("OpenaiMoe eager/FA2 attention outputs are expected to be different") + def test_flash_attn_2_equivalence(self): + pass -@unittest.skip(reason="No available checkpoint for integration tests yet") +@slow @require_torch_accelerator -class OpenAIMoeIntegrationTest(unittest.TestCase): - def tearDown(self): - # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves - # some memory allocated in the cache, which means some object is not being released properly. This causes some - # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. - # Investigate the root cause. - cleanup(torch_device, gc_collect=False) +class OpenaiMoeIntegrationTest(unittest.TestCase): + input_text = ["Hello I am doing", "Hi today"] - @slow - @require_read_token - def test_openai_3_1_hard(self): - """ - An integration test for openai 3.1. It tests against a long output to ensure the subtle numerical differences - from openai 3.1.'s RoPE can be detected - """ - # diff on `EXPECTED_TEXT`: - # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. - EXPECTED_TEXT = ( - "Tell me about the french revolution. The french revolution was a period of radical political and social " - "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " - "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " - "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " - "assembly that had not met since 1614. The Third Estate, which represented the common people, " - "demanded greater representation and eventually broke away to form the National Assembly. This marked " - "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" - ) - - tokenizer = AutoTokenizer.from_pretrained("meta-openai/Meta-Openai-3.1-8B-Instruct") - model = OpenAIMoeForCausalLM.from_pretrained( - "meta-openai/Meta-Openai-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 - ) - input_text = ["Tell me about the french revolution."] - model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) - - generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(generated_text, EXPECTED_TEXT) + def setUp(self): + cleanup(torch_device, gc_collect=True) - @slow - @require_read_token - def test_model_7b_logits_bf16(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + def tearDown(self): + cleanup(torch_device, gc_collect=True) - model = OpenAIMoeForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + @staticmethod + def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs): + if not isinstance(attn_implementation, list): + attn_implementation = [attn_implementation] + text = [] + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs).to( + torch_device ) - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - # Expected mean on dim = -1 - - # fmt: off - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), - ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), - ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) + for attn in attn_implementation: + model.set_attn_implementation(attn) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), - ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), - ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) - }) - # fmt: on - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + text += [output_text] + return text - @slow + @require_torch_large_accelerator @require_read_token - def test_model_7b_logits(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = OpenAIMoeForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map="auto", torch_dtype=torch.float16 - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - - # fmt: off - # Expected mean on dim = -1 - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), - ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), - ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), - ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) - }) - # fmt: on - - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) - - @slow - def test_model_7b_dola_generation(self): - # ground truth text generated with dola_layers="low", repetition_penalty=1.2 - EXPECTED_TEXT_COMPLETION = ( - "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " - "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " - "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " - "understanding of space and time." - ) - prompt = "Simply put, the theory of relativity states that " - tokenizer = AutoTokenizer.from_pretrained("meta-openai/Openai-2-7b-chat-hf") - model = OpenAIMoeForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 - ) - model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - - # greedy generation outputs - generated_ids = model.generate( - **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + def test_model_20b_bf16(self): + model_id = "" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + output_text = self.load_and_forward( + model_id, + ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + self.input_text, ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + self.assertEqual(output_text[0], EXPECTED_TEXTS) + self.assertEqual(output_text[1], EXPECTED_TEXTS) + self.assertEqual(output_text[2], EXPECTED_TEXTS) - @slow - @require_torch_accelerator + @require_torch_large_accelerator @require_read_token - def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - if version.parse(torch.__version__) < version.parse("2.3.0"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - NUM_TOKENS_TO_GENERATE = 40 - # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test - # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " - "theory of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " - "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] - - prompts = [ - "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", + def test_model_20b_bf16_use_kernels(self): + model_id = "" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", ] - tokenizer = AutoTokenizer.from_pretrained("meta-openai/Openai-2-7b-hf", pad_token="", padding_side="right") - model = OpenAIMoeForCausalLM.from_pretrained( - "meta-openai/Openai-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 + output_text = self.load_and_forward( + model_id, + ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + self.input_text, + use_kenels=True, ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + self.assertEqual(output_text[0], EXPECTED_TEXTS) + self.assertEqual(output_text[1], EXPECTED_TEXTS) + self.assertEqual(output_text[2], EXPECTED_TEXTS) - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - - # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + @require_torch_large_accelerator + @require_read_token + def test_model_120b_bf16_use_kernels(self): + model_id = "" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + output_text = self.load_and_forward( + model_id, + ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + self.input_text, + use_kenels=True, ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + self.assertEqual(output_text[0], EXPECTED_TEXTS) + self.assertEqual(output_text[1], EXPECTED_TEXTS) + self.assertEqual(output_text[2], EXPECTED_TEXTS) -class OpenAIMoeTokenizationIntegrationTest(unittest.TestCase): - """Ensure the HF tokenizer extracted with `OpenAIMoeConverter` remains byte-level identical to - the reference `o200k_harmony` tiktoken encoding. - """ +class OpenAIMoeTPTest(TestTensorParallel): - def setUp(self): - import os - import tiktoken - - # Load the HF tokenizer (fast implementation) - self.tokenizer = AutoTokenizer.from_pretrained("/fsx/vb/converted_model") - - # Build the (pre-release) o200k_harmony encoding for tiktoken - o200k_base = tiktoken.get_encoding("o200k_base") - self.tkt_encoding = tiktoken.Encoding( - name="o200k_harmony", - pat_str=o200k_base._pat_str, - mergeable_ranks=o200k_base._mergeable_ranks, - special_tokens={ - **o200k_base._special_tokens, - "<|startoftext|>": 199998, - "<|endoftext|>": 199999, - "<|return|>": 200002, - "<|constrain|>": 200003, - "<|channel|>": 200005, - "<|start|>": 200006, - "<|end|>": 200007, - "<|message|>": 200008, - "<|call|>": 200012, - }, - ) - def _assert_equivalent(self, string: str): - # Encode - ids_hf = self.tokenizer.encode(string) - ids_tk = self.tkt_encoding.encode(string, allowed_special="all") - self.assertEqual(ids_hf, ids_tk, msg=f"HF vs tiktoken mismatch on: {string!r}") - - # Decode round-trip (with special tokens preserved) - decoded_hf = self.tokenizer.decode(ids_hf, skip_special_tokens=False) - decoded_tk = self.tkt_encoding.decode(ids_tk) - self.assertEqual(decoded_hf, decoded_tk, msg=f"Decode diff on: {string!r}") - - @slow - def test_equivalence_on_public_datasets(self): - import tqdm - from datasets import load_dataset - - # 1) Code-to-text dataset - ds = load_dataset("google/code_x_glue_ct_code_to_text", "go") - for item in tqdm.tqdm(ds["validation"]): - self._assert_equivalent(item["code"]) - - # 2) XNLI premises across all languages - ds = load_dataset("facebook/xnli", "all_languages") - for item in tqdm.tqdm(ds["train"]): - for premise in item["premise"].values(): - self._assert_equivalent(premise) From e9f130a54f0f3fcda7188ada745412f15885d2c7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 13:03:23 +0200 Subject: [PATCH 238/342] fix top_k router call --- src/transformers/integrations/__init__.py | 17 ++- src/transformers/integrations/mxfp4.py | 142 ++++++++++-------- .../integrations/tensor_parallel.py | 3 +- src/transformers/masking_utils.py | 1 + src/transformers/modeling_utils.py | 12 +- .../convert_openai_weights_to_hf.py | 77 ++++++---- .../models/openai_moe/modeling_openai_moe.py | 2 +- .../models/openai_moe/modular_openai_moe.py | 10 +- src/transformers/quantizers/auto.py | 3 +- .../quantizers/quantizer_mxfp4.py | 113 +++++++------- src/transformers/testing_utils.py | 9 +- src/transformers/utils/__init__.py | 6 +- src/transformers/utils/generic.py | 2 +- src/transformers/utils/import_utils.py | 4 + src/transformers/utils/quantization_config.py | 2 +- .../openai_moe/test_modeling_openai_moe.py | 22 +-- tests/quantization/mxfp4/test_mxfp4.py | 2 +- tests/tensor_parallel/test_tensor_parallel.py | 128 ++++++++-------- 18 files changed, 305 insertions(+), 250 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index cbe920f9b280..01d6a7eaf4a6 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -119,7 +119,14 @@ "run_hp_search_sigopt", "run_hp_search_wandb", ], - "mxfp4": ["replace_with_mxfp4_linear", "Mxfp4OpenAIMoeExperts", "quantize_to_mxfp4", "convert_moe_packed_tensors", "dequantize", "dequantize_and_quantize"], + "mxfp4": [ + "replace_with_mxfp4_linear", + "Mxfp4OpenAIMoeExperts", + "quantize_to_mxfp4", + "convert_moe_packed_tensors", + "dequantize", + "dequantize_and_quantize", + ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], "spqr": ["replace_with_spqr_linear"], @@ -256,7 +263,13 @@ run_hp_search_sigopt, run_hp_search_wandb, ) - from .mxfp4 import replace_with_mxfp4_linear, quantize_to_mxfp4, Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize + from .mxfp4 import ( + Mxfp4OpenAIMoeExperts, + dequantize, + dequantize_and_quantize, + quantize_to_mxfp4, + replace_with_mxfp4_linear, + ) from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers from .spqr import replace_with_spqr_linear diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 7492d8b417cf..cb2e8fc9a065 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -28,10 +28,25 @@ logger = logging.get_logger(__name__) FP4_VALUES = [ - +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] + # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): @@ -47,16 +62,21 @@ def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): axis=1, swizzle_axis=swizzle_axis, swizzle_scale=swizzle_mx_scale, - swizzle_value=swizzle_mx_value + swizzle_value=swizzle_mx_value, ) - return w, InFlexData(), MicroscalingCtx( - weight_scale=mx_scales, - swizzle_scale=swizzle_mx_scale, - swizzle_value=swizzle_mx_value, - actual_weight_scale_shape=weight_scale_shape + return ( + w, + InFlexData(), + MicroscalingCtx( + weight_scale=mx_scales, + swizzle_scale=swizzle_mx_scale, + swizzle_value=swizzle_mx_value, + actual_weight_scale_shape=weight_scale_shape, + ), ) + # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public def convert_moe_packed_tensors( @@ -70,14 +90,12 @@ def convert_moe_packed_tensors( scales = scales.to(torch.int32) - 127 - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape=} does not match {scales.shape=}" - ) + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G + rows_total = math.prod(prefix_shape) * G blocks = blocks.reshape(rows_total, B) scales = scales.reshape(rows_total, 1) @@ -105,9 +123,9 @@ def convert_moe_packed_tensors( return out + class Mxfp4OpenAIMoeExperts(nn.Module): def __init__(self, config): - super().__init__() self.num_experts = config.num_local_experts @@ -116,48 +134,54 @@ def __init__(self, config): self.expert_dim = self.intermediate_size self.gate_up_proj_blocks = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size//32, 16, dtype=torch.uint8), requires_grad=False, + torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size // 32, 16, dtype=torch.uint8), + requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size//32, dtype=torch.uint8), requires_grad=False, + torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=torch.float32), requires_grad=False ) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=torch.float32), requires_grad=False) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.expert_dim, self.hidden_size//32, 16), dtype=torch.uint8), requires_grad=False, + torch.zeros((self.num_experts, self.expert_dim, self.hidden_size // 32, 16), dtype=torch.uint8), + requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.expert_dim, self.hidden_size//32, dtype=torch.uint8), requires_grad=False, + torch.zeros(self.num_experts, self.expert_dim, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.expert_dim, dtype=torch.float32), requires_grad=False ) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.expert_dim, dtype=torch.float32), requires_grad=False) self.alpha = 1.702 - self.gate_up_proj_precision_config = None self.down_proj_precision_config = None # TODO: To remove once we make sure that we don't need this # smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x - self.gate_up_proj_right_pad = 0 #smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + self.gate_up_proj_right_pad = ( + 0 # smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 + ) self.gate_up_proj_bottom_pad = 0 - self.down_proj_right_pad = 0 #smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - self.down_proj_bottom_pad = 0 #self.gate_up_proj_right_pad // 2 - self.hidden_size_pad = 0 #smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size + self.down_proj_right_pad = 0 # smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size + self.down_proj_bottom_pad = 0 # self.gate_up_proj_right_pad // 2 + self.hidden_size_pad = 0 # smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn with torch.cuda.device(hidden_states.device): - act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),(self.alpha, None), 2) + act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2) if self.hidden_size_pad is not None: hidden_states = torch.nn.functional.pad( - hidden_states, - (0, self.hidden_size_pad, 0, 0), - mode="constant", - value=0 + hidden_states, (0, self.hidden_size_pad, 0, 0), mode="constant", value=0 ) intermediate_cache1 = matmul_ogs( @@ -168,7 +192,7 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter gather_indx=gather_idx, precision_config=self.gate_up_proj_precision_config, gammas=None, - fused_activation=act + fused_activation=act, ) intermediate_cache3 = matmul_ogs( @@ -178,10 +202,12 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter routing_data, scatter_indx=scatter_idx, precision_config=self.down_proj_precision_config, - gammas=routing_data.gate_scal) + gammas=routing_data.gate_scal, + ) return intermediate_cache3 + # Adapted from GPT_OSS repo # TODO: Add absolute link when the repo is public def routing_torch_dist( @@ -217,11 +243,10 @@ def topk(vals, k, expt_indx): expt_indx, sort_indices = torch.sort(expt_indx, dim=1) expt_scal = torch.gather(expt_scal, 1, sort_indices) - # Flatten and mask for local experts expt_scal = expt_scal.reshape(-1) - hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start : local_expert_end] + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end] expt_indx = expt_indx.view(-1).to(torch.int32) @@ -238,7 +263,6 @@ def topk(vals, k, expt_indx): topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx) - # # Routing metadata for local expert computation gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) @@ -248,12 +272,15 @@ def topk(vals, k, expt_indx): hitted_experts = n_expts_act return RoutingData(gate_scal, hist, n_local_experts, hitted_experts, expt_data), gather_indx, scatter_indx + def mlp_forward(self, hidden_states): import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): routing = routing_torch_dist else: from triton_kernels.routing import routing + routing = routing hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) @@ -262,6 +289,7 @@ def mlp_forward(self, hidden_states): routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) return routed_out, router_logits + def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) if not any( @@ -270,6 +298,7 @@ def should_convert_module(current_key_name, patterns): return True return False + def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): from ..integrations.tensor_parallel import shard_and_distribute_module @@ -284,7 +313,15 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** if proj in param_name: if device_mesh is not None: param_value = shard_and_distribute_module( - model, param_value, empty_param, dq_param_name, casting_dtype, to_contiguous, rank, device_mesh, set_param=False + model, + param_value, + empty_param, + dq_param_name, + casting_dtype, + to_contiguous, + rank, + device_mesh, + set_param=False, ) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" @@ -300,7 +337,10 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** delattr(module, scales_attr) return -def dequantize_and_quantize(module, param_name, param_value, target_device, swizzle_mx_value, swizzle_mx_scale, **kwargs): + +def dequantize_and_quantize( + module, param_name, param_value, target_device, swizzle_mx_value, swizzle_mx_scale, **kwargs +): from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from ..integrations.tensor_parallel import shard_and_distribute_module @@ -327,14 +367,7 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz if blocks.device.type == "meta" and scales.device.type == "meta": if device_mesh is not None: shard_and_distribute_module( - model, - param_value, - empty_param, - param_name, - casting_dtype, - to_contiguous, - rank, - device_mesh + model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh ) else: _load_parameter_into_model(model, param_name, param_value) @@ -343,14 +376,7 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz # One of the params is already loaded, so load the other if device_mesh is not None: shard_and_distribute_module( - model, - param_value, - empty_param, - param_name, - casting_dtype, - to_contiguous, - rank, - device_mesh + model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh ) else: _load_parameter_into_model(model, param_name, param_value) @@ -361,10 +387,7 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz right_pad = getattr(module, right_pad_attr) bottom_pad = getattr(module, bottom_pad_attr) loaded_weight = torch.nn.functional.pad( - dequantized, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0 + dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) del dequantized with torch.cuda.device(target_device): @@ -375,6 +398,7 @@ def dequantize_and_quantize(module, param_name, param_value, target_device, swiz return + def _replace_with_mxfp4_linear( model, modules_to_not_convert=None, @@ -394,9 +418,10 @@ def _replace_with_mxfp4_linear( if module.__class__.__name__ == "OpenAIMoeExperts" and not quantization_config.dequantize: with init_empty_weights(): model._modules[name] = Mxfp4OpenAIMoeExperts(config) - has_been_replaced=True + has_been_replaced = True if module.__class__.__name__ == "OpenAIMoeMLP" and not quantization_config.dequantize: from types import MethodType + module.forward = MethodType(mlp_forward, module) if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_mxfp4_linear( @@ -418,7 +443,6 @@ def replace_with_mxfp4_linear( quantization_config=None, config=None, ): - if quantization_config.dequantize: return model @@ -434,7 +458,7 @@ def replace_with_mxfp4_linear( quantization_config, config=config, ) - if not has_been_replaced : + if not has_been_replaced: logger.warning( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." " Please double check your model architecture, or submit an issue on github if you think this is" diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 5c4a075eb7d4..f9746aabd41a 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -21,9 +21,9 @@ import torch import torch.distributed as dist -from torch.autograd import Function from torch import nn +from typing import Optional from ..distributed import DistributedConfig from ..utils import is_torch_greater_or_equal, logging from ..utils.generic import GeneralInterface @@ -152,7 +152,6 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig "I64": torch.int64, "F8_E4M3": torch.float8_e4m3fn, "F8_E5M2": torch.float8_e5m2, - } diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 75b246785391..901572917561 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -143,6 +143,7 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. """ + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd41477330c1..0880ca04255b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -821,16 +821,22 @@ def _load_state_dict_into_meta_model( device_mesh.get_local_rank(), device_mesh, ) - else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: + else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: sharding_kwargs = { "empty_param": empty_param, "casting_dtype": casting_dtype, "to_contiguous": to_contiguous, "rank": device_mesh.get_local_rank(), - "device_mesh": device_mesh + "device_mesh": device_mesh, } hf_quantizer.create_quantized_param( - model, param, param_name, device_mesh.get_local_rank(), state_dict, unexpected_keys, **sharding_kwargs + model, + param, + param_name, + device_mesh.get_local_rank(), + state_dict, + unexpected_keys, + **sharding_kwargs, ) else: param = param[...] diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index c8daf8a43ca5..69176fc96f0c 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -17,7 +17,7 @@ import json import os from pathlib import Path -from typing import List, Optional +from typing import Optional import regex as re import tiktoken @@ -74,11 +74,27 @@ def convert_old_keys_to_new_keys(state_dict_keys: Optional[dict] = None): output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) return output_dict + FP4_VALUES = [ - +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] + def convert_moe_packed_tensors( blocks, scales, @@ -87,16 +103,15 @@ def convert_moe_packed_tensors( rows_per_chunk: int = 32768 * 1024, ) -> torch.Tensor: import math + scales = scales.to(torch.int32) - 127 - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape=} does not match {scales.shape=}" - ) + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G + rows_total = math.prod(prefix_shape) * G blocks = blocks.reshape(rows_total, B) scales = scales.reshape(rows_total, 1) @@ -141,13 +156,15 @@ def write_model( rope_scaling = { "beta_fast": float(original_config.pop("rope_ntk_beta")), "beta_slow": float(original_config.pop("rope_ntk_alpha")), - "factor": float(original_config.pop('rope_scaling_factor')), + "factor": float(original_config.pop("rope_scaling_factor")), "rope_type": "yarn", "truncate": False, - "original_max_position_embeddings": 4096 - } + "original_max_position_embeddings": 4096, + } - config = OpenAIMoeConfig(num_local_experts=num_local_experts, rope_scaling=rope_scaling, eos_token_id=eos_token_id, **original_config) + config = OpenAIMoeConfig( + num_local_experts=num_local_experts, rope_scaling=rope_scaling, eos_token_id=eos_token_id, **original_config + ) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") final_ = {} @@ -188,12 +205,12 @@ def write_model( # deal with packed weights blocks = final_[key] scales = final_[key.replace("blocks", "scales")] - new_key = new_key.replace(".blocks","") + new_key = new_key.replace(".blocks", "") unpacked_tensors = convert_moe_packed_tensors(blocks, scales, dtype=torch.bfloat16) unpacked_tensors = unpacked_tensors.permute(0, 2, 1).contiguous() # einsum in orignal, I use bmm state_dict[new_key] = unpacked_tensors else: - raise(f"Unidentified {key}, please double check the state dict") + raise (f"Unidentified {key}, please double check the state dict") else: if "scales" in new_key: new_key = new_key.replace(".scales", "_scales") @@ -202,7 +219,7 @@ def write_model( new_key = new_key.replace(".blocks", "_blocks") state_dict[new_key] = final_[key].contiguous() else: - raise(f"Unidentified {key}, please double check the state dict") + raise (f"Unidentified {key}, please double check the state dict") else: weight = final_[key] if not re.search("norm", new_key): @@ -227,13 +244,14 @@ def write_model( else: print("Saving the checkpoint in mxfp4 format") config.quantization_config = { - "quant_method": "mxfp4", - "modules_to_not_convert":[ - "model.layers.*.self_attn", - "model.layers.*.mlp.router", - "model.embed_tokens", - "lm_head" - ]} + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + } config.save_pretrained(model_path) save_sharded_model(state_dict, model_path) del state_dict @@ -263,7 +281,7 @@ def write_model( def save_sharded_model(state_dict, model_path): from safetensors.torch import save_file - max_shard_size = 4800000000 # 4.8 GB + max_shard_size = 4800000000 # 4.8 GB os.makedirs(model_path, exist_ok=True) shard_size_counter = 0 shard_id = 0 @@ -273,7 +291,7 @@ def save_sharded_model(state_dict, model_path): safetensors_index["metadata"] = {"total_size": 0} safetensors_index["weight_map"] = {} for key in state_dict.keys(): - size = state_dict[key].numel()*state_dict[key].element_size() + size = state_dict[key].numel() * state_dict[key].element_size() safetensors_index["metadata"]["total_size"] += size safetensors_index["weight_map"][key] = shard_id if shard_size_counter + size > max_shard_size: @@ -286,12 +304,10 @@ def save_sharded_model(state_dict, model_path): total_sharded_dict[shard_id] = shard_state_dict num_shards = len(total_sharded_dict) - 1 for shard_id, shard_state_dict in total_sharded_dict.items(): - save_file( - shard_state_dict, - os.path.join(model_path, f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors") - ) + save_file(shard_state_dict, os.path.join(model_path, f"model-{shard_id:05d}-of-{num_shards:05d}.safetensors")) create_safetensors_index(safetensors_index, num_shards, model_path) + def create_safetensors_index(safetensors_index, num_shards, model_path): for key in safetensors_index["weight_map"].keys(): shard_id = safetensors_index["weight_map"][key] @@ -386,9 +402,7 @@ def __init__( special_tokens_map.setdefault(f"<|reserved_{k}|>", k) # Keep only token strings (sorted by ID) for TikTokenConverter. - self.additional_special_tokens = [ - tok for tok, _ in sorted(special_tokens_map.items(), key=lambda x: x[1]) - ] + self.additional_special_tokens = [tok for tok, _ in sorted(special_tokens_map.items(), key=lambda x: x[1])] tokenizer = self.converted() if chat_template is not None: kwargs["chat_template"] = chat_template @@ -402,6 +416,7 @@ def __init__( **kwargs, ) + def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): # Updated Harmony chat template chat_template = """{# Harmony chat template -------------------------------------------------- @@ -505,7 +520,7 @@ def main(): parser.add_argument( "--special_tokens", default=None, - type=List[str], + type=list[str], help="The list of special tokens that should be added to the ", ) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index c09e3a3bf883..5556e4c5ba81 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -149,7 +149,7 @@ def forward(self, hidden_states): class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = TopKRouter(config) + self.router = OpenAiMoeTopKRouter(config) self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 56b773a4845c..f784fd528a74 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Optional import torch from torch import nn @@ -26,12 +26,12 @@ from ...modeling_rope_utils import dynamic_rope_update from ...processing_utils import Unpack from ...utils import ( + TransformersKwargs, auto_docstring, logging, - TransformersKwargs, use_kernel_forward_from_hub, ) -from ...utils.generic import check_model_inputs, OutputRecorder +from ...utils.generic import OutputRecorder, check_model_inputs from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaPreTrainedModel, @@ -147,7 +147,7 @@ def forward(self, hidden_states): class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = TopKRouter(config) + self.router = OpenAiMoeTopKRouter(config) self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): @@ -335,7 +335,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 88f133ac713d..f5f959310c6b 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -215,7 +215,8 @@ def merge_quantization_configs( if ( isinstance( - quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config) + quantization_config, + (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config), ) and quantization_config_from_args is not None ): diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index dc20ab3cb284..74f99a822aaa 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -60,9 +60,7 @@ def validate_environment(self, *args, **kwargs): raise RuntimeError("Using MXFP4 quantized models requires a GPU") if not is_accelerate_available(): - raise ImportError( - "Using mxfp4 requires Accelerate: `pip install accelerate`" - ) + raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`") if self.quantization_config.dequantize: return @@ -79,11 +77,9 @@ def validate_environment(self, *args, **kwargs): return else: # we can't quantize the model in this case so we raise an error - raise ValueError( - "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" - ) + raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed") - if major < 9 : + if major < 9: raise ValueError( "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)" ) @@ -106,6 +102,7 @@ def validate_environment(self, *args, **kwargs): "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." ) from triton_kernels.numerics_details.mxfp import SwizzlingType + # TODO: Explain what swizzle_mx_value and swizzle_mx_scale are if major < 9: # NYI for Ampere @@ -146,11 +143,13 @@ def check_quantized_param( # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): - module, tensor_name = get_module_from_name(model, param_name[:-len("_blocks")]) + module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")]) else: module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): + if isinstance(module, Mxfp4OpenAIMoeExperts) or ( + isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize + ): if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]: return False return True @@ -179,28 +178,24 @@ def create_quantized_param( if "gate_up_proj" in param_name: right_pad = module.gate_up_proj_right_pad bottom_pad = module.gate_up_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(param_value, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0) + loaded_weight = torch.nn.functional.pad( + param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 + ) loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, - self.swizzle_mx_value, - self.swizzle_mx_scale + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale + ) + module.gate_up_proj_precision_config = PrecisionConfig( + mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex) ) - module.gate_up_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) elif "down_proj" in param_name: right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad - loaded_weight = torch.nn.functional.pad(param_value, - (0, right_pad, 0, bottom_pad, 0, 0), - mode="constant", - value=0).to(target_device) + loaded_weight = torch.nn.functional.pad( + param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 + ).to(target_device) loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, - self.swizzle_mx_value, - self.swizzle_mx_scale + loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale ) module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) @@ -214,32 +209,27 @@ def create_quantized_param( device_mesh = kwargs.get("device_mesh", None) if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize: # blocks and scales have the same length that's this works for both - module, _ = get_module_from_name(model, param_name[:-len("_blocks")]) + module, _ = get_module_from_name(model, param_name[: -len("_blocks")]) else: module, _ = get_module_from_name(model, param_name) shard_kwargs = { - "empty_param": empty_param, - "casting_dtype": casting_dtype, - "to_contiguous": to_contiguous, - "rank": rank, - "device_mesh": device_mesh, - "model": model, + "empty_param": empty_param, + "casting_dtype": casting_dtype, + "to_contiguous": to_contiguous, + "rank": rank, + "device_mesh": device_mesh, + "model": model, } - if isinstance(module, Mxfp4OpenAIMoeExperts) or (isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize): + if isinstance(module, Mxfp4OpenAIMoeExperts) or ( + isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize + ): if self.quantization_config.dequantize: # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears # so we only have the original param name - dq_param_name = param_name[:-len("_blocks")] - dequantize( - module, - param_name, - param_value, - target_device, - dq_param_name, - **shard_kwargs - ) + dq_param_name = param_name[: -len("_blocks")] + dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs) else: dequantize_and_quantize( module, @@ -248,8 +238,8 @@ def create_quantized_param( target_device, self.swizzle_mx_value, self.swizzle_mx_scale, - **shard_kwargs - ) + **shard_kwargs, + ) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): # we are not really dequantizing, we are just removing everthing related to quantization here @@ -261,11 +251,11 @@ def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str new_expected_keys = [] for key in expected_keys: if key.endswith(".mlp.experts.gate_up_proj"): - base = key[:-len("gate_up_proj")] + base = key[: -len("gate_up_proj")] new_expected_keys.append(base + "gate_up_proj_blocks") new_expected_keys.append(base + "gate_up_proj_scales") elif key.endswith(".mlp.experts.down_proj"): - base = key[:-len("down_proj")] + base = key[: -len("down_proj")] new_expected_keys.append(base + "down_proj_blocks") new_expected_keys.append(base + "down_proj_scales") else: @@ -311,28 +301,33 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li def update_tp_plan(self, config): if "OpenAIMoeConfig" in config.__class__.__name__: - if (not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None) or (not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None): + if (not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None) or ( + not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None + ): return config # Update TP plan with scales and blocks - config.base_model_tp_plan.update({ - "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj_blocks": "local_colwise", - "layers.*.mlp.experts.down_proj_scales": "local_colwise", - }) + config.base_model_tp_plan.update( + { + "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", + "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", + "layers.*.mlp.experts.down_proj_blocks": "local_colwise", + "layers.*.mlp.experts.down_proj_scales": "local_colwise", + } + ) # Update EP plan with scales and blocks - config.base_model_ep_plan.update({ - "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", - "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", - }) + config.base_model_ep_plan.update( + { + "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", + } + ) return config - def update_param_name(self, param_name: str) -> str: if self.quantization_config.dequantize: if "_blocks" in param_name: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 181eff65cb02..020564dd2454 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -457,23 +457,26 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}" )(test_case) + def require_triton(min_version: str = TRITON_MIN_VERSION): """ Decorator marking a test that requires triton. These tests are skipped when triton isn't installed. """ + def decorator(test_case): return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")( test_case ) + return decorator + def require_triton_kernels(test_case): """ Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed. """ - return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")( - test_case - ) + return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case) + def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): """ diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 41ad23e7b11e..9d31896e8330 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -114,10 +114,10 @@ from .import_utils import ( ACCELERATE_MIN_VERSION, ENV_VARS_TRUE_AND_AUTO_VALUES, - TRITON_MIN_VERSION, ENV_VARS_TRUE_VALUES, GGUF_MIN_VERSION, TORCH_FX_REQUIRED_VERSION, + TRITON_MIN_VERSION, USE_JAX, USE_TF, USE_TORCH, @@ -268,6 +268,8 @@ is_torchvision_available, is_torchvision_v2_available, is_training_run_on_sagemaker, + is_triton_available, + is_triton_kernels_availalble, is_uroman_available, is_vision_available, is_vptq_available, @@ -275,8 +277,6 @@ is_yt_dlp_available, requires_backends, torch_only_method, - is_triton_available, - is_triton_kernels_availalble, ) from .peft_utils import ( ADAPTER_CONFIG_NAME, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index bf0792055389..92badb421510 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -1039,7 +1039,7 @@ def wrapped_forward(*args, **kwargs): collected_outputs[key] += (output,) elif output[index] is not None: if key not in collected_outputs: - collected_outputs[key] =(output[index],) + collected_outputs[key] = (output[index],) else: collected_outputs[key] += (output[index],) return output diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1dfe478a34f4..d787320484f3 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -408,12 +408,15 @@ def is_torch_deterministic(): return False + def is_triton_available(min_version: str = TRITON_MIN_VERSION): return _triton_available and version.parse(_triton_version) >= version.parse(min_version) + def is_triton_kernels_availalble(): return _triton_kernels_available + def is_hadamard_available(): return _hadamard_available @@ -1591,6 +1594,7 @@ def is_liger_kernel_available(): return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") + def is_rich_available(): return _rich_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 1069d02e4765..ea807b9c51a9 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -2076,4 +2076,4 @@ def __init__( def get_loading_attributes(self): return { "dequantize": self.dequantize, - } \ No newline at end of file + } diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 014f6d6cbb69..2dc2a9cd4c71 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -16,24 +16,16 @@ import unittest import pytest -from packaging import version from parameterized import parameterized -from pytest import mark -from tests.tensor_parallel.test_tensor_parallel import TestTensorParallel -from transformers import AutoModelForCausalLM, AutoTokenizer, OpenaiMoeConfig, is_torch_available, pipeline -from transformers.generation.configuration_utils import GenerationConfig +from tests.tensor_parallel.test_tensor_parallel import TensorParallelTestBase +from transformers import AutoModelForCausalLM, AutoTokenizer, OpenaiMoeConfig, is_torch_available from transformers.testing_utils import ( - Expectations, cleanup, - is_flash_attn_2_available, - require_flash_attn, - require_large_cpu_ram, require_read_token, require_torch, require_torch_accelerator, require_torch_large_accelerator, - require_torch_large_gpu, slow, torch_device, ) @@ -247,6 +239,14 @@ def test_model_120b_bf16_use_kernels(self): self.assertEqual(output_text[1], EXPECTED_TEXTS) self.assertEqual(output_text[2], EXPECTED_TEXTS) -class OpenAIMoeTPTest(TestTensorParallel): +class OpenAIMoeTPTest(TensorParallelTestBase): + def test_model_training(self): + self.run_tensor_parallel_test( + model_id="openai/openai-moe-20b", mode="training", expected_output="you with something?" + ) + def test_model_generate(self): + self.run_tensor_parallel_test( + model_id="openai/openai-moe-20b", mode="generate", expected_output="with something" + ) diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index 7dead0fff025..70f451b25b9f 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -417,4 +417,4 @@ def test_memory_footprint_comparison(self): ) quantized_mem = quantized_model.get_memory_footprint() dequantized_mem = dequantized_model.get_memory_footprint() - self.assertLess(quantized_mem, dequantized_mem) \ No newline at end of file + self.assertLess(quantized_mem, dequantized_mem) diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index 1904fc8bd1e7..ff64c1296ecc 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -18,7 +18,7 @@ import subprocess import tempfile import textwrap - +from typing import Optional from transformers import is_torch_available from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights from transformers.testing_utils import ( @@ -64,76 +64,68 @@ def size(self): assert torch.allclose(unpacked_weights, original_packed_weights) -class TestTensorParallel(TestCasePlus): +class TensorParallelTestBase(TestCasePlus): nproc_per_node = 2 - def torchrun(self, script: str, is_torchrun: bool = True): - """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" + def run_torch_distributed_test(self, script: str, is_torchrun: bool = True): + """Run the given Python script in a subprocess using torchrun or python3.""" with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: tmp.write(script) tmp.flush() tmp.seek(0) - if is_torchrun: - cmd = ( - f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" + cmd = ( + ( + f"torchrun --nproc_per_node {self.nproc_per_node} " + f"--master_port {get_torch_dist_unique_port()} {tmp.name}" ).split() - else: - cmd = ["python3", tmp.name] + if is_torchrun + else ["python3", tmp.name] + ) - # Note that the subprocess will be waited for here, and raise an error if not successful try: - _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True) + subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True) except subprocess.CalledProcessError as e: - raise Exception(f"The following error was captured: {e.stderr}") - - def test_model_forward(self): - script_to_run = textwrap.dedent( - """ - import torch - import os - from transformers import AutoModelForCausalLM, AutoTokenizer - - model_id = "JackFram/llama-68m" - - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") - torch.distributed.barrier() - - has_dtensor = 0 - for name, parameter in model.named_parameters(): - if isinstance(parameter.data, torch.distributed.tensor.DTensor): - has_dtensor = 1 - break - - assert has_dtensor == 1, "TP model must has DTensor" - - tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) - prompt = "Can I help" - - inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) - outputs = model(inputs) + raise Exception(f"Subprocess failed with:\nSTDOUT:\n{e.stdout}\n\nSTDERR:\n{e.stderr}") + + def run_tensor_parallel_test(self, model_id: str, mode: str = "training", expected_output: str = None): + """ + Runs a tensor-parallel test for either training (forward) or generate mode. + + Args: + model_id: The model to test. + mode: "training" or "generate". + expected_output: Token or string to assert for training mode. + """ + if mode not in ("training", "generate"): + raise ValueError(f"Invalid mode '{mode}', must be 'training' or 'generate'") + + # Only the outputs line changes between training and generate + outputs_line = ( + "outputs = model(inputs)" + if mode == "training" + else 'outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")' + ) + # Expected assertion differs slightly depending on the mode + if mode == "training": + assertion = f""" next_token_logits = outputs[0][:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1) response = tokenizer.decode(next_token) - assert response == "with" - - torch.distributed.barrier() - torch.distributed.destroy_process_group() + assert response == "{expected_output}", f"Expected token '{{expected_output}}', got '{{response}}'" """ - ) - self.torchrun(script_to_run) - - def test_model_generate(self): - script_to_run = textwrap.dedent( + else: + assertion = """ + output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'" """ + + script = f""" import torch import os from transformers import AutoModelForCausalLM, AutoTokenizer - model_id = "JackFram/llama-68m" + model_id = "{model_id}" rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) @@ -141,30 +133,24 @@ def test_model_generate(self): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") torch.distributed.barrier() - model.forward = torch.compile(model.forward) - - has_dtensor = 0 - for name, parameter in model.named_parameters(): - if isinstance(parameter.data, torch.distributed.tensor.DTensor): - has_dtensor = 1 - break + has_dtensor = any( + isinstance(p.data, torch.distributed.tensor.DTensor) + for _, p in model.named_parameters() + ) + assert has_dtensor, "TP model must have DTensor" - assert has_dtensor == 1, "TP model must has DTensor" - - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) prompt = "Can I help" inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) - outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static") + {outputs_line} - output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) - assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'" + {assertion} torch.distributed.barrier() torch.distributed.destroy_process_group() - """ - ) - self.torchrun(script_to_run) + """ + self.run_torch_distributed_test(textwrap.dedent(script)) @require_huggingface_hub_greater_or_equal("0.31.4") def test_model_save(self): @@ -209,6 +195,14 @@ def test_model_save(self): del non_tp_tensor, tp_tensor +class TestTensorParallel(TensorParallelTestBase): + def test_model_training(self): + self.run_tensor_parallel_test(model_id="JackFram/llama-68m", mode="training", expected_output="with") + + def test_model_generate(self): + self.run_tensor_parallel_test(model_id="JackFram/llama-68m", mode="generate") + + @require_torch_multi_accelerator class TestTensorParallelAccelerator(TestTensorParallel): nproc_per_node = backend_device_count(torch_device) From da77d5e3979a4849ea74cd1f067dca3678478f83 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 13:21:34 +0200 Subject: [PATCH 239/342] revert renaming --- .../integrations/tensor_parallel.py | 2 +- .../openai_moe/configuration_openai_moe.py | 15 +---- .../models/openai_moe/modeling_openai_moe.py | 9 +-- .../models/openai_moe/modular_openai_moe.py | 8 +-- .../openai_moe/test_modeling_openai_moe.py | 58 +++++++++---------- tests/tensor_parallel/test_tensor_parallel.py | 2 +- 6 files changed, 41 insertions(+), 53 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index f9746aabd41a..df9340cdf1ef 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -18,12 +18,12 @@ import os import re from functools import partial, reduce +from typing import Optional import torch import torch.distributed as dist from torch import nn -from typing import Optional from ..distributed import DistributedConfig from ..utils import is_torch_greater_or_equal, logging from ..utils.generic import GeneralInterface diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/openai_moe/configuration_openai_moe.py index 581a623280c4..76b4e4fca651 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/openai_moe/configuration_openai_moe.py @@ -26,25 +26,12 @@ class OpenAIMoeConfig(PretrainedConfig): """ model_type = "openai_moe" - base_model_tp_plan = { - # "embed_tokens": "vocab_parallel_rowwise", - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "local_rowwise", - "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_bias": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj": "local_colwise", - "layers.*.mlp.experts.down_proj_bias": "local", # TODO: add smthg that says bias exists only once for all TPs - # "layers.*.mlp": "gather", - } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } - base_model_ep_plan = { + base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 5556e4c5ba81..088d3575f198 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -26,13 +26,14 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations.hub_kernels import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, use_kernel_forward_from_hub +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_openai_moe import OpenAIMoeConfig @@ -127,7 +128,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenAiMoeTopKRouter(nn.Module): +class OpenAIMoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -149,7 +150,7 @@ def forward(self, hidden_states): class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = OpenAiMoeTopKRouter(config) + self.router = OpenAIMoeTopKRouter(config) self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): @@ -382,7 +383,7 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(OpenAiMoeTopKRouter, index=2), + "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=2), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention, } diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index f784fd528a74..32eb60ad75bd 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -19,6 +19,7 @@ from torch.nn import functional as F from ...cache_utils import Cache, DynamicCache +from ...integrations.hub_kernels import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_outputs import ( MoeModelOutputWithPast, @@ -29,7 +30,6 @@ TransformersKwargs, auto_docstring, logging, - use_kernel_forward_from_hub, ) from ...utils.generic import OutputRecorder, check_model_inputs from ..llama.modeling_llama import ( @@ -125,7 +125,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenAiMoeTopKRouter(nn.Module): +class OpenAIMoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -147,7 +147,7 @@ def forward(self, hidden_states): class OpenAIMoeMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = OpenAiMoeTopKRouter(config) + self.router = OpenAIMoeTopKRouter(config) self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): @@ -299,7 +299,7 @@ class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _supports_sdpa = False _supports_flex_attention = False _can_record_outputs = { - "router_logits": OutputRecorder(OpenAiMoeTopKRouter, index=2), + "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=2), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention, } diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 2dc2a9cd4c71..bdfbdfeeb8e6 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch OpenaiMoe model.""" +"""Testing suite for the PyTorch OpenAIMoe model.""" import unittest @@ -19,7 +19,7 @@ from parameterized import parameterized from tests.tensor_parallel.test_tensor_parallel import TensorParallelTestBase -from transformers import AutoModelForCausalLM, AutoTokenizer, OpenaiMoeConfig, is_torch_available +from transformers import AutoModelForCausalLM, AutoTokenizer, OpenAIMoeConfig, is_torch_available from transformers.testing_utils import ( cleanup, require_read_token, @@ -43,16 +43,16 @@ ) -class OpenaiMoeModelTester(CausalLMModelTester): +class OpenAIMoeModelTester(CausalLMModelTester): if is_torch_available(): - config_class = OpenaiMoeConfig + config_class = OpenAIMoeConfig base_model_class = OpenAIMoeModel causal_lm_class = OpenAIMoeForCausalLM pipeline_model_mapping = ( { - "feature-extraction": OpenaiMoeModel, - "text-generation": OpenaiMoeForCausalLM, + "feature-extraction": OpenAIMoeModel, + "text-generation": OpenAIMoeForCausalLM, } if is_torch_available() else {} @@ -60,12 +60,12 @@ class OpenaiMoeModelTester(CausalLMModelTester): @require_torch -class OpenaiMoeModelTest(CausalLMModelTest, unittest.TestCase): - all_model_classes = (OpenaiMoeModel, OpenaiMoeForCausalLM) if is_torch_available() else () +class OpenAIMoeModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (OpenAIMoeModel, OpenAIMoeForCausalLM) if is_torch_available() else () pipeline_model_mapping = ( { - "feature-extraction": OpenaiMoeModel, - "text-generation": OpenaiMoeForCausalLM, + "feature-extraction": OpenAIMoeModel, + "text-generation": OpenAIMoeForCausalLM, } if is_torch_available() else {} @@ -75,68 +75,68 @@ class OpenaiMoeModelTest(CausalLMModelTest, unittest.TestCase): test_pruning = False _is_stateful = True model_split_percents = [0.5, 0.6] - model_tester_class = OpenaiMoeModelTester + model_tester_class = OpenAIMoeModelTester def setUp(self): - self.model_tester = OpenaiMoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenaiMoeConfig, hidden_size=37) + self.model_tester = OpenAIMoeModelTester(self) + self.config_tester = ConfigTester(self, config_class=OpenAIMoeConfig, hidden_size=37) @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): pass - @unittest.skip("OpenaiMoe's forcefully disables sdpa due to Sink") + @unittest.skip("OpenAIMoe's forcefully disables sdpa due to Sink") def test_sdpa_can_dispatch_non_composite_models(self): pass - @unittest.skip("OpenaiMoe's eager attn/sdpa attn outputs are expected to be different") + @unittest.skip("OpenAIMoe's eager attn/sdpa attn outputs are expected to be different") def test_eager_matches_sdpa_generate(self): pass @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate - @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass - @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass @pytest.mark.generate - @unittest.skip("OpenaiMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass - @unittest.skip("OpenaiMoe has HybridCache which is not compatible with dola decoding") + @unittest.skip("OpenAIMoe has HybridCache which is not compatible with dola decoding") def test_dola_decoding_sample(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support continue from past kv") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support continue from past kv") def test_generate_continue_from_past_key_values(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") def test_contrastive_generate(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") def test_contrastive_generate_dict_outputs_use_cache(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") def test_contrastive_generate_low_memory(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_with_static_cache(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("OpenaiMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_continue_from_inputs_embeds(self): pass @@ -147,18 +147,18 @@ def test_generate_continue_from_inputs_embeds(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip("OpenaiMoe has HybridCache which auto-compiles. Compile and FA2 don't work together.") + @unittest.skip("OpenAIMoe has HybridCache which auto-compiles. Compile and FA2 don't work together.") def test_eager_matches_fa2_generate(self): pass - @unittest.skip("OpenaiMoe eager/FA2 attention outputs are expected to be different") + @unittest.skip("OpenAIMoe eager/FA2 attention outputs are expected to be different") def test_flash_attn_2_equivalence(self): pass @slow @require_torch_accelerator -class OpenaiMoeIntegrationTest(unittest.TestCase): +class OpenAIMoeIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] def setUp(self): diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index ff64c1296ecc..b2ac43b14d1b 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -18,7 +18,7 @@ import subprocess import tempfile import textwrap -from typing import Optional + from transformers import is_torch_available from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights from transformers.testing_utils import ( From 5b0bd40285d23b9b4ab47dc87d44564ce4bf023e Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 13:29:27 +0200 Subject: [PATCH 240/342] test nits --- tests/models/openai_moe/test_modeling_openai_moe.py | 8 ++++++++ tests/tensor_parallel/test_tensor_parallel.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index bdfbdfeeb8e6..6de0a1b25595 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -26,6 +26,7 @@ require_torch, require_torch_accelerator, require_torch_large_accelerator, + require_torch_multi_accelerator, slow, torch_device, ) @@ -240,13 +241,20 @@ def test_model_120b_bf16_use_kernels(self): self.assertEqual(output_text[2], EXPECTED_TEXTS) +@require_torch_multi_accelerator class OpenAIMoeTPTest(TensorParallelTestBase): def test_model_training(self): self.run_tensor_parallel_test( model_id="openai/openai-moe-20b", mode="training", expected_output="you with something?" ) + self.run_tensor_parallel_test( + model_id="openai/openai-moe-120b", mode="training", expected_output="you with something?" + ) def test_model_generate(self): self.run_tensor_parallel_test( model_id="openai/openai-moe-20b", mode="generate", expected_output="with something" ) + self.run_tensor_parallel_test( + model_id="openai/openai-moe-120b", mode="generate", expected_output="with something" + ) diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index b2ac43b14d1b..30282f22d74a 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -177,7 +177,7 @@ def test_model_save(self): model.save_pretrained(result_dir) """ ) - self.torchrun(script_to_run, is_torchrun=is_torchrun) + self.run_torch_distributed_test(script_to_run, is_torchrun=is_torchrun) non_tp_model_path = os.path.join(tmp_dir, "nontp") tp_model_path = os.path.join(tmp_dir, "tp") From b43d2cd40eb581a348ec643a4af6ebc04aa0c7cf Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 13:59:15 +0200 Subject: [PATCH 241/342] small fixes for EP --- src/transformers/integrations/tensor_parallel.py | 2 +- tests/models/openai_moe/test_modeling_openai_moe.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index df9340cdf1ef..d7c96b2091ce 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -939,7 +939,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me """ ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_local_experts = mod.num_experts // ep_size - router_scores, router_indices = outputs + router_scores, router_indices = outputs[0], outputs[1] router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) router_indices = router_indices % num_local_experts diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 6de0a1b25595..d6db76b8c2af 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -187,7 +187,6 @@ def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwa text += [output_text] return text - @require_torch_large_accelerator @require_read_token def test_model_20b_bf16(self): model_id = "" @@ -204,7 +203,6 @@ def test_model_20b_bf16(self): self.assertEqual(output_text[1], EXPECTED_TEXTS) self.assertEqual(output_text[2], EXPECTED_TEXTS) - @require_torch_large_accelerator @require_read_token def test_model_20b_bf16_use_kernels(self): model_id = "" @@ -222,7 +220,6 @@ def test_model_20b_bf16_use_kernels(self): self.assertEqual(output_text[1], EXPECTED_TEXTS) self.assertEqual(output_text[2], EXPECTED_TEXTS) - @require_torch_large_accelerator @require_read_token def test_model_120b_bf16_use_kernels(self): model_id = "" @@ -241,6 +238,7 @@ def test_model_120b_bf16_use_kernels(self): self.assertEqual(output_text[2], EXPECTED_TEXTS) +@slow @require_torch_multi_accelerator class OpenAIMoeTPTest(TensorParallelTestBase): def test_model_training(self): From 13ec4ef3a052178f23a71753452c2ec57c99ffe3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 30 Jul 2025 11:59:40 +0000 Subject: [PATCH 242/342] fix path for our local tests --- tests/models/openai_moe/test_modeling_openai_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 6de0a1b25595..930b2b686c89 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -245,16 +245,16 @@ def test_model_120b_bf16_use_kernels(self): class OpenAIMoeTPTest(TensorParallelTestBase): def test_model_training(self): self.run_tensor_parallel_test( - model_id="openai/openai-moe-20b", mode="training", expected_output="you with something?" + model_id="/fsx/vb/new-oai/20b-converted-quantized", mode="training", expected_output="you with something?" ) self.run_tensor_parallel_test( - model_id="openai/openai-moe-120b", mode="training", expected_output="you with something?" + model_id="/fsx/vb/new-oai/120b-converted-quantized", mode="training", expected_output="you with something?" ) def test_model_generate(self): self.run_tensor_parallel_test( - model_id="openai/openai-moe-20b", mode="generate", expected_output="with something" + model_id="/fsx/vb/new-oai/20b-converted-quantized", mode="generate", expected_output="with something" ) self.run_tensor_parallel_test( - model_id="openai/openai-moe-120b", mode="generate", expected_output="with something" + model_id="/fsx/vb/new-oai/120b-converted-quantized", mode="generate", expected_output="with something" ) From 0276225a686d4a0adb0f0e216652843adaaec7da Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 14:04:21 +0200 Subject: [PATCH 243/342] update as I should not have broken that! --- src/transformers/integrations/tensor_parallel.py | 2 +- src/transformers/models/openai_moe/modeling_openai_moe.py | 6 +++--- src/transformers/models/openai_moe/modular_openai_moe.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d7c96b2091ce..df9340cdf1ef 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -939,7 +939,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me """ ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() num_local_experts = mod.num_experts // ep_size - router_scores, router_indices = outputs[0], outputs[1] + router_scores, router_indices = outputs router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) router_indices = router_indices % num_local_experts diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 088d3575f198..ac6de23fb52c 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -143,7 +143,7 @@ def forward(self, hidden_states): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices, router_logits + return router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -154,7 +154,7 @@ def __init__(self, config): self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): - router_scores, router_indices, _ = self.router(hidden_states) # (num_experts, seq_len) + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) return routed_out @@ -383,7 +383,7 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=2), + "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=0), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention, } diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index 32eb60ad75bd..b91d83fd457b 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -140,7 +140,7 @@ def forward(self, hidden_states): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices, router_logits + return router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -151,7 +151,7 @@ def __init__(self, config): self.experts = OpenAIMoeExperts(config) def forward(self, hidden_states): - router_scores, router_indices, _ = self.router(hidden_states) # (num_experts, seq_len) + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) return routed_out @@ -299,7 +299,7 @@ class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _supports_sdpa = False _supports_flex_attention = False _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=2), + "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=0), "hidden_states": OpenAIMoeDecoderLayer, "attentions": OpenAIMoeAttention, } From a34b39ca02adfa248dd2e9c4f4d6b0b7a8cdfb34 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 14:13:34 +0200 Subject: [PATCH 244/342] fix the loss of mixtral --- src/transformers/models/mixtral/modeling_mixtral.py | 9 ++++++--- .../models/openai_moe/modeling_openai_moe.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 043862a3a2c0..d52aaea702c1 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -543,8 +543,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -553,7 +553,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index ac6de23fb52c..8a4f17db189f 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -564,8 +564,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -574,7 +574,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts From e7cc5914aa1fa2b965a0605785fda9c9fa657302 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 14:15:56 +0200 Subject: [PATCH 245/342] revert part of the changes related to router_scores, kernel probably no ready for that! --- .../models/openai_moe/modeling_openai_moe.py | 8 ++-- .../models/openai_moe/modular_openai_moe.py | 39 +++++++++++++++++-- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/openai_moe/modeling_openai_moe.py index 8a4f17db189f..48e6e9f260e1 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/openai_moe/modeling_openai_moe.py @@ -156,7 +156,7 @@ def __init__(self, config): def forward(self, hidden_states): router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out + return routed_out, router_scores class OpenAIMoeRotaryEmbedding(nn.Module): @@ -315,8 +315,8 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=self.sliding_window, # main diff with Llama - s_aux=self.sinks, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama **kwargs, ) @@ -364,7 +364,7 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores hidden_states = residual + hidden_states return hidden_states diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/openai_moe/modular_openai_moe.py index b91d83fd457b..1c10ee01d791 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/openai_moe/modular_openai_moe.py @@ -153,7 +153,7 @@ def __init__(self, config): def forward(self, hidden_states): router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out + return routed_out, router_scores class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding): @@ -273,8 +273,8 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=self.sliding_window, # main diff with Llama - s_aux=self.sinks, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama **kwargs, ) @@ -293,6 +293,39 @@ def __init__(self, config: OpenAIMoeConfig, layer_idx: int): self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + return hidden_states + class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] From f1245b4caad8cd85bf89b1cc95e5d0d90ab0040e Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Jul 2025 14:26:05 +0200 Subject: [PATCH 246/342] deleting a small nit --- src/transformers/integrations/tensor_parallel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index df9340cdf1ef..245b25802a54 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -880,8 +880,6 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype) if to_contiguous: param = param.contiguous() - if "gate_up" in param_type and False: - param = torch.cat([param[..., ::2], param[..., 1::2]], dim=-1) return param From 9b387ca9e5a15b2e45e7772ce103b32dd30a19ae Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 30 Jul 2025 13:02:30 +0000 Subject: [PATCH 247/342] update arch --- .../models/openai_moe/convert_openai_weights_to_hf.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 69176fc96f0c..401201d47fe8 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -252,17 +252,16 @@ def write_model( "lm_head", ], } + # required as we don't save the model with save_pretrained + config.architectures = ["OpenAIMoeForCausalLM"] config.save_pretrained(model_path) save_sharded_model(state_dict, model_path) del state_dict - # Safety check: reload the converted model gc.collect() - # TODO: remove when mxfp4 pr is merged - if not mxfp4: - print("Reloading the model to check if it's saved correctly.") - OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") - print("Model reloaded successfully.") + print("Reloading the model to check if it's saved correctly.") + OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") # generation config if instruct: From 6c0effa971eeccc81592a6b84d1d2ce35a5f7113 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Wed, 30 Jul 2025 13:13:33 +0000 Subject: [PATCH 248/342] fix post processing --- src/transformers/modeling_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 057b05f9aee2..f38cc16959f4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -884,7 +884,8 @@ def _load_state_dict_into_meta_model( param_name = hf_quantizer.update_param_name(param_name) module, param_type = get_module_from_name(model, param_name) value = getattr(module, param_type) - if value.device.type == "meta": + # special case for OpenAIMoeForCausalLM, we wait for the param to be leave the meta device before casting it to cpu + if model.__class__.__name__ == "OpenAIMoeForCausalLM" and value.device.type == "meta": continue param_to = "cpu" if is_fsdp_enabled() and not is_local_dist_rank_0(): @@ -5121,8 +5122,8 @@ def _assign_original_dtype(module): dispatch_model(model, **device_map_kwargs) if hf_quantizer is not None: - hf_quantizer.postprocess_model(model, config=config) model.hf_quantizer = hf_quantizer + hf_quantizer.postprocess_model(model, config=config) if _adapter_model_path is not None: adapter_kwargs["key_mapping"] = key_mapping From ab0f92951169206dbac4b53e48a156a44dc7f5ea Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 30 Jul 2025 14:10:41 +0000 Subject: [PATCH 249/342] update --- src/transformers/integrations/mxfp4.py | 116 +++++++++++++----- .../integrations/tensor_parallel.py | 1 + .../quantizers/quantizer_mxfp4.py | 36 ++---- src/transformers/utils/import_utils.py | 5 + 4 files changed, 96 insertions(+), 62 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index cb2e8fc9a065..f39a52173d5b 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -49,32 +49,84 @@ # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public -def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): - from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx +# def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): +# from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx +# from triton_kernels.numerics_details.mxfp import downcast_to_mxfp + +# swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None +# w = w.to(torch.bfloat16) + +# w, mx_scales, weight_scale_shape = downcast_to_mxfp( +# w, +# torch.uint8, +# axis=1, +# swizzle_axis=swizzle_axis, +# swizzle_scale=swizzle_mx_scale, +# swizzle_value=swizzle_mx_value, +# ) + +# return ( +# w, +# InFlexData(), +# MicroscalingCtx( +# weight_scale=mx_scales, +# swizzle_scale=swizzle_mx_scale, +# swizzle_value=swizzle_mx_value, +# actual_weight_scale_shape=weight_scale_shape, +# ), +# ) + +def quantize_to_mxfp4(w): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp - - swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None - w = w.to(torch.bfloat16) - - w, mx_scales, weight_scale_shape = downcast_to_mxfp( - w, - torch.uint8, - axis=1, - swizzle_axis=swizzle_axis, - swizzle_scale=swizzle_mx_scale, - swizzle_value=swizzle_mx_value, - ) - - return ( - w, - InFlexData(), - MicroscalingCtx( - weight_scale=mx_scales, - swizzle_scale=swizzle_mx_scale, - swizzle_value=swizzle_mx_value, - actual_weight_scale_shape=weight_scale_shape, - ), - ) + from triton_kernels.tensor import convert_layout + from triton_kernels.tensor import wrap_torch_tensor, FP4 + from triton_kernels.tensor_details import layout + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + num_warps = 8 + + w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1,num_warps=num_warps) + + #TODO: check if needed + # if current_platform.is_cuda() and \ + # torch.cuda.get_device_capability()[0] == 10: + # constraints = { + # "is_persistent": True, + # "epilogue_subtile": 1, + # } + # opt_flags.update_opt_flags_constraints(constraints) + # # transpose the tensor so that the quantization axis is on dim1 + # quant_tensor = quant_tensor.transpose(-2, -1) + # scale = scale.transpose(-2, -1) + + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) + w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) + return w, w_scale + +# def swizzle_mxfp4(quant_tensor, scale, num_warps): +# value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( +# mx_axis=1) +# scale_layout, scale_layout_opts = ( +# layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, +# num_warps=num_warps)) +# if current_platform.is_cuda() and \ +# torch.cuda.get_device_capability()[0] == 10: +# constraints = { +# "is_persistent": True, +# "epilogue_subtile": 1, +# } +# opt_flags.update_opt_flags_constraints(constraints) +# # transpose the tensor so that the quantization axis is on dim1 +# quant_tensor = quant_tensor.transpose(-2, -1) +# scale = scale.transpose(-2, -1) +# quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), +# value_layout, **value_layout_opts) +# scale = convert_layout(wrap_torch_tensor(scale), scale_layout, +# **scale_layout_opts) +# return quant_tensor, InFlexData(), scale # Copied from GPT_OSS repo @@ -339,9 +391,9 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** def dequantize_and_quantize( - module, param_name, param_value, target_device, swizzle_mx_value, swizzle_mx_scale, **kwargs + module, param_name, param_value, target_device, **kwargs ): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, InFlexData from ..integrations.tensor_parallel import shard_and_distribute_module from ..modeling_utils import _load_parameter_into_model @@ -386,16 +438,14 @@ def dequantize_and_quantize( right_pad = getattr(module, right_pad_attr) bottom_pad = getattr(module, bottom_pad_attr) - loaded_weight = torch.nn.functional.pad( + dequantized = torch.nn.functional.pad( dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) - del dequantized with torch.cuda.device(target_device): - loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) - - setattr(module, precision_config_attr, PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex))) - setattr(module, proj, torch.nn.Parameter(loaded_weight, requires_grad=False)) + quantized, weight_scale = quantize_to_mxfp4(dequantized) + setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) + setattr(module, proj, torch.nn.Parameter(quantized, requires_grad=False)) return diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 245b25802a54..7898907d73c7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1085,6 +1085,7 @@ def shard_and_distribute_module( logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") if current_shard_plan is not None: + print(current_shard_plan) try: tp_layer = ALL_PARALLEL_STYLES[current_shard_plan] param = tp_layer.partition_tensor( diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 74f99a822aaa..b3e2db105845 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -101,22 +101,6 @@ def validate_environment(self, *args, **kwargs): "This is not supported when the model is quantized on the fly. " "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." ) - from triton_kernels.numerics_details.mxfp import SwizzlingType - - # TODO: Explain what swizzle_mx_value and swizzle_mx_scale are - if major < 9: - # NYI for Ampere - swizzle_mx_value = None - swizzle_mx_scale = None - elif major < 10: - swizzle_mx_value = SwizzlingType.HOPPER - swizzle_mx_scale = None - else: - swizzle_mx_value = None - swizzle_mx_scale = SwizzlingType.BLACKWELL - - self.swizzle_mx_value = swizzle_mx_value - self.swizzle_mx_scale = swizzle_mx_scale def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: @@ -166,7 +150,7 @@ def create_quantized_param( **kwargs, ): if is_triton_kernels_availalble(): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, InFlexData from ..integrations import Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts @@ -181,24 +165,20 @@ def create_quantized_param( loaded_weight = torch.nn.functional.pad( param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) - loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale - ) + quantized, weight_scale = quantize_to_mxfp4(loaded_weight) module.gate_up_proj_precision_config = PrecisionConfig( - mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex) + weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()) ) - module.gate_up_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + module.gate_up_proj = torch.nn.Parameter(quantized, requires_grad=False) elif "down_proj" in param_name: right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad loaded_weight = torch.nn.functional.pad( param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ).to(target_device) - loaded_weight, flex, mx = quantize_to_mxfp4( - loaded_weight, self.swizzle_mx_value, self.swizzle_mx_scale - ) - module.down_proj_precision_config = PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex)) - module.down_proj = torch.nn.Parameter(loaded_weight, requires_grad=False) + quantized, weight_scale = quantize_to_mxfp4(loaded_weight) + module.down_proj_precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())) + module.down_proj = torch.nn.Parameter(quantized, requires_grad=False) # we take this path if already quantized but not in a compatible way # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales else: @@ -236,8 +216,6 @@ def create_quantized_param( param_name, param_value, target_device, - self.swizzle_mx_value, - self.swizzle_mx_scale, **shard_kwargs, ) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index d787320484f3..9f8e21e89945 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -76,6 +76,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ package_version = importlib.metadata.version("amd-quark") except Exception: package_exists = False + elif pkg_name == "triton": + try: + package_version = importlib.metadata.version("pytorch-triton") + except Exception: + package_exists = False else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False From c80bd4482360b71cb6e7b2c2486b9de5862060c1 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 30 Jul 2025 14:43:27 +0000 Subject: [PATCH 250/342] running version but not expected output --- src/transformers/integrations/mxfp4.py | 6 +-- .../integrations/tensor_parallel.py | 1 - .../quantizers/quantizer_mxfp4.py | 37 ++++++------------- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index f39a52173d5b..f36249db9fb2 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -442,10 +442,10 @@ def dequantize_and_quantize( dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) with torch.cuda.device(target_device): - quantized, weight_scale = quantize_to_mxfp4(dequantized) - + triton_weight_tensor, weight_scale = quantize_to_mxfp4(dequantized) setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) - setattr(module, proj, torch.nn.Parameter(quantized, requires_grad=False)) + setattr(module, proj, triton_weight_tensor) + setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) return diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 7898907d73c7..245b25802a54 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1085,7 +1085,6 @@ def shard_and_distribute_module( logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") if current_shard_plan is not None: - print(current_shard_plan) try: tp_layer = ALL_PARALLEL_STYLES[current_shard_plan] param = tp_layer.partition_tensor( diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index b3e2db105845..b78dd2b0fed1 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -165,20 +165,23 @@ def create_quantized_param( loaded_weight = torch.nn.functional.pad( param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) - quantized, weight_scale = quantize_to_mxfp4(loaded_weight) + triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight) module.gate_up_proj_precision_config = PrecisionConfig( weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()) ) - module.gate_up_proj = torch.nn.Parameter(quantized, requires_grad=False) + module.gate_up_proj = triton_weight_tensor + # module.gate_up_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) elif "down_proj" in param_name: right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad loaded_weight = torch.nn.functional.pad( param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ).to(target_device) - quantized, weight_scale = quantize_to_mxfp4(loaded_weight) + triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight) module.down_proj_precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())) - module.down_proj = torch.nn.Parameter(quantized, requires_grad=False) + module.down_proj = triton_weight_tensor + # module.down_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) + # we take this path if already quantized but not in a compatible way # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales else: @@ -279,31 +282,15 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li def update_tp_plan(self, config): if "OpenAIMoeConfig" in config.__class__.__name__: - if (not hasattr(config, "base_model_tp_plan") or config.base_model_tp_plan is None) or ( - not hasattr(config, "base_model_ep_plan") or config.base_model_ep_plan is None - ): - return config - - # Update TP plan with scales and blocks - config.base_model_tp_plan.update( - { - "layers.*.mlp.experts.gate_up_proj_blocks": "local_packed_rowwise", - "layers.*.mlp.experts.gate_up_proj_scales": "local_packed_rowwise", - "layers.*.mlp.experts.down_proj_blocks": "local_colwise", - "layers.*.mlp.experts.down_proj_scales": "local_colwise", - } - ) - - # Update EP plan with scales and blocks - config.base_model_ep_plan.update( - { + if getattr(config, "base_model_tp_plan", None) is not None: + config.base_model_tp_plan.update( + { "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", - } - ) - + } + ) return config def update_param_name(self, param_name: str) -> str: From dc1251830c0427130df9980ed77b2bcf5d8958aa Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 10:29:46 +0000 Subject: [PATCH 251/342] moving to cuda --- src/transformers/integrations/mxfp4.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 512e79978ac3..3e93da66c3f4 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -88,6 +88,11 @@ def convert_moe_packed_tensors( ) -> torch.Tensor: import math + # Check if blocks and scales are on CPU, and move to GPU if so + if not blocks.is_cuda and torch.cuda.is_available(): + blocks = blocks.cuda() + scales = scales.cuda() + scales = scales.to(torch.int32) - 127 assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" @@ -121,6 +126,11 @@ def convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + # TODO: Delete after making sure this is not necessary! since we go back to cpu in the end in create_quantized_param using .to(target_device) + # Move back to CPU if needed + # if need_to_move_back: + # out = out.cpu() + return out @@ -332,6 +342,9 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** setattr(module, param_name.rsplit(".", 1)[1], param_value) dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) dequantized = dequantized.transpose(1, 2).contiguous().to(target_device) + # TODO: this is perhaps necessary since if target_device is cpu, and the param was on gpu + if target_device == "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache() setattr(module, proj, torch.nn.Parameter(dequantized)) delattr(module, blocks_attr) delattr(module, scales_attr) From 20dfa56da4a45258198db2cce61981ad06bcf7d0 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 10:33:35 +0000 Subject: [PATCH 252/342] initial commit --- .../integrations/integration_utils.py | 201 +----------------- 1 file changed, 6 insertions(+), 195 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 41a32b63acbb..da43d0213cc3 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -833,211 +833,22 @@ def setup(self, args, state, model, **kwargs): - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`): Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable. """ - if self._wandb is None: - return - self._initialized = True - - # prepare to handle potential configuration issues during setup - from wandb.sdk.lib.config_util import ConfigError as WandbConfigError - - if state.is_world_process_zero: - logger.info( - 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' - ) - combined_dict = {**args.to_dict()} - - if hasattr(model, "config") and model.config is not None: - model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() - combined_dict = {**model_config, **combined_dict} - if hasattr(model, "peft_config") and model.peft_config is not None: - peft_config = model.peft_config - combined_dict = {**{"peft_config": peft_config}, **combined_dict} - trial_name = state.trial_name - init_args = {} - if trial_name is not None: - init_args["name"] = trial_name - init_args["group"] = args.run_name - elif args.run_name is not None: - init_args["name"] = args.run_name - if args.run_name == args.output_dir: - self._wandb.termwarn( - "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was " - "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.", - repeat=False, - ) - - if self._wandb.run is None: - self._wandb.init( - project=os.getenv("WANDB_PROJECT", "huggingface"), - **init_args, - ) - # add config parameters (run may have been created manually) - self._wandb.config.update(combined_dict or {}, allow_val_change=True) - - # define default x-axis (for latest wandb versions) - if getattr(self._wandb, "define_metric", None): - self._wandb.define_metric("train/global_step") - self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True) - - # keep track of model topology and gradients, unsupported on TPU - _watch_model = os.getenv("WANDB_WATCH", "false") - if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"): - self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) - self._wandb.run._label(code="transformers_trainer") - - # add number of model parameters to wandb config - try: - self._wandb.config["model/num_parameters"] = model.num_parameters() - except AttributeError: - logger.info( - "Could not log the number of model parameters in Weights & Biases due to an AttributeError." - ) - except WandbConfigError: - logger.warning( - "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config." - ) - - # log the initial model architecture to an artifact - if self._log_model.is_enabled: - with tempfile.TemporaryDirectory() as temp_dir: - model_name = ( - f"model-{self._wandb.run.id}" - if (args.run_name is None or args.run_name == args.output_dir) - else f"model-{self._wandb.run.name}" - ) - model_artifact = self._wandb.Artifact( - name=model_name, - type="model", - metadata={ - "model_config": model.config.to_dict() if hasattr(model, "config") else None, - "num_parameters": self._wandb.config.get("model/num_parameters"), - "initial_model": True, - }, - ) - # add the architecture to a separate text file - save_model_architecture_to_file(model, temp_dir) - - for f in Path(temp_dir).glob("*"): - if f.is_file(): - with model_artifact.new_file(f.name, mode="wb") as fa: - fa.write(f.read_bytes()) - self._wandb.run.log_artifact(model_artifact, aliases=["base_model"]) - - badge_markdown = ( - f'[Visualize in Weights & Biases]({self._wandb.run.url})' - ) - - modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + pass def on_train_begin(self, args, state, control, model=None, **kwargs): - if self._wandb is None: - return - hp_search = state.is_hyper_param_search - if hp_search: - self._wandb.finish() - self._initialized = False - args.run_name = None - if not self._initialized: - self.setup(args, state, model, **kwargs) + pass def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs): - if self._wandb is None: - return - if self._log_model.is_enabled and self._initialized and state.is_world_process_zero: - from ..trainer import Trainer - - args_for_fake = copy.deepcopy(args) - args_for_fake.deepspeed = None - args_for_fake.deepspeed_plugin = None - fake_trainer = Trainer( - args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"] - ) - with tempfile.TemporaryDirectory() as temp_dir: - fake_trainer.save_model(temp_dir) - metadata = ( - { - k: v - for k, v in dict(self._wandb.summary).items() - if isinstance(v, numbers.Number) and not k.startswith("_") - } - if not args.load_best_model_at_end - else { - f"eval/{args.metric_for_best_model}": state.best_metric, - "train/total_floss": state.total_flos, - "model/num_parameters": self._wandb.config.get("model/num_parameters"), - } - ) - metadata["final_model"] = True - logger.info("Logging model artifacts. ...") - model_name = ( - f"model-{self._wandb.run.id}" - if (args.run_name is None or args.run_name == args.output_dir) - else f"model-{self._wandb.run.name}" - ) - # add the model architecture to a separate text file - save_model_architecture_to_file(model, temp_dir) - - artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata) - for f in Path(temp_dir).glob("*"): - if f.is_file(): - with artifact.new_file(f.name, mode="wb") as fa: - fa.write(f.read_bytes()) - self._wandb.run.log_artifact(artifact, aliases=["final_model"]) + pass def on_log(self, args, state, control, model=None, logs=None, **kwargs): - single_value_scalars = [ - "train_runtime", - "train_samples_per_second", - "train_steps_per_second", - "train_loss", - "total_flos", - ] - - if self._wandb is None: - return - if not self._initialized: - self.setup(args, state, model) - if state.is_world_process_zero: - for k, v in logs.items(): - if k in single_value_scalars: - self._wandb.run.summary[k] = v - non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars} - non_scalar_logs = rewrite_logs(non_scalar_logs) - self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step}) + pass def on_save(self, args, state, control, **kwargs): - if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero: - checkpoint_metadata = { - k: v - for k, v in dict(self._wandb.summary).items() - if isinstance(v, numbers.Number) and not k.startswith("_") - } - checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters") - - ckpt_dir = f"checkpoint-{state.global_step}" - artifact_path = os.path.join(args.output_dir, ckpt_dir) - logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...") - checkpoint_name = ( - f"model-{self._wandb.run.id}" - if (args.run_name is None or args.run_name == args.output_dir) - else f"model-{self._wandb.run.name}" - ) - artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata) - artifact.add_dir(artifact_path) - self._wandb.log_artifact( - artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"] - ) + pass def on_predict(self, args, state, control, metrics, **kwargs): - if self._wandb is None: - return - if not self._initialized: - self.setup(args, state, **kwargs) - if state.is_world_process_zero: - metrics = rewrite_logs(metrics) - self._wandb.log(metrics) + pass class TrackioCallback(TrainerCallback): From 228a982652095a8d14fa65a690b23ed0b225e6a7 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 10:35:26 +0000 Subject: [PATCH 253/342] revert --- .../integrations/integration_utils.py | 201 +++++++++++++++++- 1 file changed, 195 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index da43d0213cc3..41a32b63acbb 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -833,22 +833,211 @@ def setup(self, args, state, model, **kwargs): - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`): Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable. """ - pass + if self._wandb is None: + return + self._initialized = True + + # prepare to handle potential configuration issues during setup + from wandb.sdk.lib.config_util import ConfigError as WandbConfigError + + if state.is_world_process_zero: + logger.info( + 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' + ) + combined_dict = {**args.to_dict()} + + if hasattr(model, "config") and model.config is not None: + model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() + combined_dict = {**model_config, **combined_dict} + if hasattr(model, "peft_config") and model.peft_config is not None: + peft_config = model.peft_config + combined_dict = {**{"peft_config": peft_config}, **combined_dict} + trial_name = state.trial_name + init_args = {} + if trial_name is not None: + init_args["name"] = trial_name + init_args["group"] = args.run_name + elif args.run_name is not None: + init_args["name"] = args.run_name + if args.run_name == args.output_dir: + self._wandb.termwarn( + "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was " + "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.", + repeat=False, + ) + + if self._wandb.run is None: + self._wandb.init( + project=os.getenv("WANDB_PROJECT", "huggingface"), + **init_args, + ) + # add config parameters (run may have been created manually) + self._wandb.config.update(combined_dict or {}, allow_val_change=True) + + # define default x-axis (for latest wandb versions) + if getattr(self._wandb, "define_metric", None): + self._wandb.define_metric("train/global_step") + self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True) + + # keep track of model topology and gradients, unsupported on TPU + _watch_model = os.getenv("WANDB_WATCH", "false") + if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"): + self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) + self._wandb.run._label(code="transformers_trainer") + + # add number of model parameters to wandb config + try: + self._wandb.config["model/num_parameters"] = model.num_parameters() + except AttributeError: + logger.info( + "Could not log the number of model parameters in Weights & Biases due to an AttributeError." + ) + except WandbConfigError: + logger.warning( + "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config." + ) + + # log the initial model architecture to an artifact + if self._log_model.is_enabled: + with tempfile.TemporaryDirectory() as temp_dir: + model_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + model_artifact = self._wandb.Artifact( + name=model_name, + type="model", + metadata={ + "model_config": model.config.to_dict() if hasattr(model, "config") else None, + "num_parameters": self._wandb.config.get("model/num_parameters"), + "initial_model": True, + }, + ) + # add the architecture to a separate text file + save_model_architecture_to_file(model, temp_dir) + + for f in Path(temp_dir).glob("*"): + if f.is_file(): + with model_artifact.new_file(f.name, mode="wb") as fa: + fa.write(f.read_bytes()) + self._wandb.run.log_artifact(model_artifact, aliases=["base_model"]) + + badge_markdown = ( + f'[Visualize in Weights & Biases]({self._wandb.run.url})' + ) + + modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" def on_train_begin(self, args, state, control, model=None, **kwargs): - pass + if self._wandb is None: + return + hp_search = state.is_hyper_param_search + if hp_search: + self._wandb.finish() + self._initialized = False + args.run_name = None + if not self._initialized: + self.setup(args, state, model, **kwargs) def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs): - pass + if self._wandb is None: + return + if self._log_model.is_enabled and self._initialized and state.is_world_process_zero: + from ..trainer import Trainer + + args_for_fake = copy.deepcopy(args) + args_for_fake.deepspeed = None + args_for_fake.deepspeed_plugin = None + fake_trainer = Trainer( + args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"] + ) + with tempfile.TemporaryDirectory() as temp_dir: + fake_trainer.save_model(temp_dir) + metadata = ( + { + k: v + for k, v in dict(self._wandb.summary).items() + if isinstance(v, numbers.Number) and not k.startswith("_") + } + if not args.load_best_model_at_end + else { + f"eval/{args.metric_for_best_model}": state.best_metric, + "train/total_floss": state.total_flos, + "model/num_parameters": self._wandb.config.get("model/num_parameters"), + } + ) + metadata["final_model"] = True + logger.info("Logging model artifacts. ...") + model_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + # add the model architecture to a separate text file + save_model_architecture_to_file(model, temp_dir) + + artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata) + for f in Path(temp_dir).glob("*"): + if f.is_file(): + with artifact.new_file(f.name, mode="wb") as fa: + fa.write(f.read_bytes()) + self._wandb.run.log_artifact(artifact, aliases=["final_model"]) def on_log(self, args, state, control, model=None, logs=None, **kwargs): - pass + single_value_scalars = [ + "train_runtime", + "train_samples_per_second", + "train_steps_per_second", + "train_loss", + "total_flos", + ] + + if self._wandb is None: + return + if not self._initialized: + self.setup(args, state, model) + if state.is_world_process_zero: + for k, v in logs.items(): + if k in single_value_scalars: + self._wandb.run.summary[k] = v + non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars} + non_scalar_logs = rewrite_logs(non_scalar_logs) + self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step}) def on_save(self, args, state, control, **kwargs): - pass + if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero: + checkpoint_metadata = { + k: v + for k, v in dict(self._wandb.summary).items() + if isinstance(v, numbers.Number) and not k.startswith("_") + } + checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters") + + ckpt_dir = f"checkpoint-{state.global_step}" + artifact_path = os.path.join(args.output_dir, ckpt_dir) + logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...") + checkpoint_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata) + artifact.add_dir(artifact_path) + self._wandb.log_artifact( + artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"] + ) def on_predict(self, args, state, control, metrics, **kwargs): - pass + if self._wandb is None: + return + if not self._initialized: + self.setup(args, state, **kwargs) + if state.is_world_process_zero: + metrics = rewrite_logs(metrics) + self._wandb.log(metrics) class TrackioCallback(TrainerCallback): From 5a597336bbadf140cefa5664a83b9daa70bff2f9 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 12:17:47 +0000 Subject: [PATCH 254/342] erroring when loading on cpu --- src/transformers/quantizers/quantizer_mxfp4.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 74f99a822aaa..891d851d9929 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -85,10 +85,12 @@ def validate_environment(self, *args, **kwargs): ) device_map = kwargs.get("device_map", None) - if device_map is None: - logger.warning_once( + if device_map is None and not self.quantization_config.dequantize: + raise ValueError( "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set " "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " + "If you are attempting to train the model, please consider dequantizing the model first by passing " + "quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()" ) elif device_map is not None: if ( @@ -342,4 +344,5 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: + logger.warning_once("MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()") return False From 910ccfecaf73731f1125a449d23337f66048e80a Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 13:42:47 +0000 Subject: [PATCH 255/342] updates --- src/transformers/integrations/mxfp4.py | 9 +++++++-- src/transformers/modeling_utils.py | 2 +- src/transformers/quantizers/quantizer_mxfp4.py | 6 ++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 3e93da66c3f4..39437e5b5995 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..modeling_utils import is_deepspeed_zero3_enabled, is_fsdp_enabled from ..utils import is_accelerate_available, is_torch_available, logging @@ -403,11 +404,15 @@ def dequantize_and_quantize( dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) del dequantized + + original_device = target_device + if (is_fsdp_enabled() or is_deepspeed_zero3_enabled()) and target_device == "cpu": + loaded_weight = loaded_weight.cuda() + target_device = "cuda" with torch.cuda.device(target_device): loaded_weight, flex, mx = quantize_to_mxfp4(loaded_weight, swizzle_mx_value, swizzle_mx_scale) - setattr(module, precision_config_attr, PrecisionConfig(mx_ctx=mx, flex_ctx=FlexCtx(rhs_data=flex))) - setattr(module, proj, torch.nn.Parameter(loaded_weight, requires_grad=False)) + setattr(module, proj, torch.nn.Parameter(loaded_weight.to(original_device), requires_grad=False)) return diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e0271cd90a81..aed2a787062a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -897,7 +897,7 @@ def _load_state_dict_into_meta_model( if is_fsdp_enabled() and not is_local_dist_rank_0(): param_to = "meta" val_kwargs = {} - if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params": + if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (value.dtype == torch.uint8 or value.dtype == torch.int8): val_kwargs["requires_grad"] = False value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) setattr(module, param_type, value) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 891d851d9929..ffe5538545ea 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -85,12 +85,10 @@ def validate_environment(self, *args, **kwargs): ) device_map = kwargs.get("device_map", None) - if device_map is None and not self.quantization_config.dequantize: - raise ValueError( + if device_map is None: + logger.warning_once( "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set " "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " - "If you are attempting to train the model, please consider dequantizing the model first by passing " - "quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()" ) elif device_map is not None: if ( From 212acd0f0ef1a84c5ff2285ab32156a9f6fda9e1 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 13:51:07 +0000 Subject: [PATCH 256/342] del blocks, scales --- src/transformers/integrations/mxfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 39437e5b5995..94cbbf9c9077 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -131,7 +131,7 @@ def convert_moe_packed_tensors( # Move back to CPU if needed # if need_to_move_back: # out = out.cpu() - + del blocks, scales return out From 5c6d3b2c6f389ab641f632da873c4f450ac2468d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 14:27:55 +0000 Subject: [PATCH 257/342] fix --- src/transformers/integrations/mxfp4.py | 80 +++++--------------------- src/transformers/modeling_utils.py | 6 +- 2 files changed, 17 insertions(+), 69 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 75e25438c094..c3d48d16fbfe 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -47,88 +47,36 @@ ] -# Copied from GPT_OSS repo -# TODO: Add absolute link when the repo is public -# def quantize_to_mxfp4(w, swizzle_mx_value, swizzle_mx_scale): -# from triton_kernels.matmul_ogs import InFlexData, MicroscalingCtx -# from triton_kernels.numerics_details.mxfp import downcast_to_mxfp - -# swizzle_axis = 2 if swizzle_mx_scale or swizzle_mx_value else None -# w = w.to(torch.bfloat16) - -# w, mx_scales, weight_scale_shape = downcast_to_mxfp( -# w, -# torch.uint8, -# axis=1, -# swizzle_axis=swizzle_axis, -# swizzle_scale=swizzle_mx_scale, -# swizzle_value=swizzle_mx_value, -# ) - -# return ( -# w, -# InFlexData(), -# MicroscalingCtx( -# weight_scale=mx_scales, -# swizzle_scale=swizzle_mx_scale, -# swizzle_value=swizzle_mx_value, -# actual_weight_scale_shape=weight_scale_shape, -# ), -# ) - +# Copied from GPT_OSS repo and vllm def quantize_to_mxfp4(w): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp from triton_kernels.tensor import convert_layout from triton_kernels.tensor import wrap_torch_tensor, FP4 from triton_kernels.tensor_details import layout + from triton_kernels.tensor_details.layout import StridedLayout + import triton_kernels.matmul_ogs_details.opt_flags as opt_flags - # FIXME warp need to be adjusted based on batch size - # only apply to batched mode - num_warps = 8 - w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) - scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1,num_warps=num_warps) + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) - #TODO: check if needed - # if current_platform.is_cuda() and \ - # torch.cuda.get_device_capability()[0] == 10: + # TODO : add that when we are actually sure that it works on B200 + # if torch.cuda.get_device_capability()[0] == 10: # constraints = { # "is_persistent": True, # "epilogue_subtile": 1, # } # opt_flags.update_opt_flags_constraints(constraints) # # transpose the tensor so that the quantization axis is on dim1 - # quant_tensor = quant_tensor.transpose(-2, -1) - # scale = scale.transpose(-2, -1) + + + # TODO: there is still an issue with the scales on hopper + # scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8) + # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) + w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) - w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) - w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) return w, w_scale -# def swizzle_mxfp4(quant_tensor, scale, num_warps): -# value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( -# mx_axis=1) -# scale_layout, scale_layout_opts = ( -# layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, -# num_warps=num_warps)) -# if current_platform.is_cuda() and \ -# torch.cuda.get_device_capability()[0] == 10: -# constraints = { -# "is_persistent": True, -# "epilogue_subtile": 1, -# } -# opt_flags.update_opt_flags_constraints(constraints) -# # transpose the tensor so that the quantization axis is on dim1 -# quant_tensor = quant_tensor.transpose(-2, -1) -# scale = scale.transpose(-2, -1) -# quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), -# value_layout, **value_layout_opts) -# scale = convert_layout(wrap_torch_tensor(scale), scale_layout, -# **scale_layout_opts) -# return quant_tensor, InFlexData(), scale - - # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public def convert_moe_packed_tensors( @@ -434,7 +382,7 @@ def dequantize_and_quantize( _load_parameter_into_model(model, param_name, param_value) dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) - dequantized = dequantized.transpose(1, 2).to(target_device) + dequantized = dequantized.transpose(1, 2).contiguous().to(target_device) right_pad = getattr(module, right_pad_attr) bottom_pad = getattr(module, bottom_pad_attr) @@ -444,7 +392,7 @@ def dequantize_and_quantize( with torch.cuda.device(target_device): triton_weight_tensor, weight_scale = quantize_to_mxfp4(dequantized) setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) - setattr(module, proj, triton_weight_tensor) + setattr(module, proj, triton_weight_tensor) setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) return diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e0271cd90a81..a8937f2c11aa 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4961,9 +4961,6 @@ def from_pretrained( # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) - if _torch_distributed_available and device_mesh is not None: - model = distribute_model(model, distributed_config, device_mesh, tp_size) - # Make sure to tie the weights correctly model.tie_weights() @@ -5008,6 +5005,9 @@ def _assign_original_dtype(module): config._pre_quantization_dtype = original_dtype _assign_original_dtype(model) + + if _torch_distributed_available and device_mesh is not None: + model = distribute_model(model, distributed_config, device_mesh, tp_size) # Prepare the full device map if device_map is not None: From 5ec240fcdb4f86fa077721c83df32361be1f49c9 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 14:31:00 +0000 Subject: [PATCH 258/342] style --- src/transformers/integrations/mxfp4.py | 14 ++++++-------- src/transformers/modeling_utils.py | 2 +- src/transformers/quantizers/quantizer_mxfp4.py | 2 +- .../models/openai_moe/test_modeling_openai_moe.py | 1 - 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index c3d48d16fbfe..600617a324bd 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -50,16 +50,14 @@ # Copied from GPT_OSS repo and vllm def quantize_to_mxfp4(w): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp - from triton_kernels.tensor import convert_layout - from triton_kernels.tensor import wrap_torch_tensor, FP4 + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.tensor_details.layout import StridedLayout - import triton_kernels.matmul_ogs_details.opt_flags as opt_flags - + w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) - + # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -69,12 +67,12 @@ def quantize_to_mxfp4(w): # opt_flags.update_opt_flags_constraints(constraints) # # transpose the tensor so that the quantization axis is on dim1 - + # TODO: there is still an issue with the scales on hopper # scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8) # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) - + return w, w_scale # Copied from GPT_OSS repo @@ -341,7 +339,7 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** def dequantize_and_quantize( module, param_name, param_value, target_device, **kwargs ): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, InFlexData + from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig from ..integrations.tensor_parallel import shard_and_distribute_module from ..modeling_utils import _load_parameter_into_model diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a8937f2c11aa..8a43df51aeda 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5005,7 +5005,7 @@ def _assign_original_dtype(module): config._pre_quantization_dtype = original_dtype _assign_original_dtype(model) - + if _torch_distributed_available and device_mesh is not None: model = distribute_model(model, distributed_config, device_mesh, tp_size) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index b78dd2b0fed1..122d1315c3a6 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -150,7 +150,7 @@ def create_quantized_param( **kwargs, ): if is_triton_kernels_availalble(): - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, InFlexData + from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig from ..integrations import Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/openai_moe/test_modeling_openai_moe.py index 992e06f476c8..d91d39a218ea 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/openai_moe/test_modeling_openai_moe.py @@ -25,7 +25,6 @@ require_read_token, require_torch, require_torch_accelerator, - require_torch_large_accelerator, require_torch_multi_accelerator, slow, torch_device, From 2faa7ca4dc8e55608faeb1a9bd955361f7b655d6 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 14:34:06 +0000 Subject: [PATCH 259/342] rm comm --- src/transformers/quantizers/quantizer_mxfp4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 122d1315c3a6..45cae4579345 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -170,7 +170,7 @@ def create_quantized_param( weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()) ) module.gate_up_proj = triton_weight_tensor - # module.gate_up_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) + module.gate_up_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) elif "down_proj" in param_name: right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad @@ -180,7 +180,7 @@ def create_quantized_param( triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight) module.down_proj_precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())) module.down_proj = triton_weight_tensor - # module.down_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) + module.down_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) # we take this path if already quantized but not in a compatible way # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales From c5b8cecdf69fb370f417f04bba73c81be06e16de Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Thu, 31 Jul 2025 14:35:56 +0000 Subject: [PATCH 260/342] comment --- src/transformers/integrations/mxfp4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 94cbbf9c9077..b59b063c0a4b 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -406,6 +406,7 @@ def dequantize_and_quantize( del dequantized original_device = target_device + # for fsdp and deepspeed since the model is load on cpu, we need to move the weight to gpu for quantization if (is_fsdp_enabled() or is_deepspeed_zero3_enabled()) and target_device == "cpu": loaded_weight = loaded_weight.cuda() target_device = "cuda" From 79dd4fc14469c383bee6e0c4f3164845b7fd9ce7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 14:55:12 +0000 Subject: [PATCH 261/342] add comment --- src/transformers/integrations/mxfp4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 600617a324bd..0999315202bf 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -390,6 +390,7 @@ def dequantize_and_quantize( with torch.cuda.device(target_device): triton_weight_tensor, weight_scale = quantize_to_mxfp4(dequantized) setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) + # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor setattr(module, proj, triton_weight_tensor) setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) return From d238ea4eb50d8a9e761914ea7d024912adce4e61 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 15:04:20 +0000 Subject: [PATCH 262/342] style --- src/transformers/integrations/mxfp4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 861854e976ea..133329b04ea1 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -420,7 +420,6 @@ def dequantize_and_quantize( # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor setattr(module, proj, triton_weight_tensor) setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) - return From a7dd97fd049dac143d838084ef62912de0e9aaf8 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 31 Jul 2025 15:07:30 +0000 Subject: [PATCH 263/342] remove duplicated lines --- src/transformers/integrations/mxfp4.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 133329b04ea1..8cf25e57a95f 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -398,16 +398,10 @@ def dequantize_and_quantize( right_pad = getattr(module, right_pad_attr) bottom_pad = getattr(module, bottom_pad_attr) + dequantized = torch.nn.functional.pad( dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ) - with torch.cuda.device(target_device): - triton_weight_tensor, weight_scale = quantize_to_mxfp4(dequantized) - setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) - # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor - setattr(module, proj, triton_weight_tensor) - setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) - original_device = target_device # for fsdp and deepspeed since the model is load on cpu, we need to move the weight to gpu for quantization if (is_fsdp_enabled() or is_deepspeed_zero3_enabled()) and target_device == "cpu": From cf4843b42d87c7b92e8319ee562a078ac9249925 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 1 Aug 2025 00:20:07 +0000 Subject: [PATCH 264/342] Fix minor issue with weight_map conversion script --- .../models/openai_moe/convert_openai_weights_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 401201d47fe8..6ebcd20585d5 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -291,8 +291,6 @@ def save_sharded_model(state_dict, model_path): safetensors_index["weight_map"] = {} for key in state_dict.keys(): size = state_dict[key].numel() * state_dict[key].element_size() - safetensors_index["metadata"]["total_size"] += size - safetensors_index["weight_map"][key] = shard_id if shard_size_counter + size > max_shard_size: total_sharded_dict[shard_id] = shard_state_dict shard_id += 1 @@ -300,6 +298,8 @@ def save_sharded_model(state_dict, model_path): shard_state_dict = {} shard_state_dict[key] = state_dict[key] shard_size_counter += size + safetensors_index["metadata"]["total_size"] += size + safetensors_index["weight_map"][key] = shard_id total_sharded_dict[shard_id] = shard_state_dict num_shards = len(total_sharded_dict) - 1 for shard_id, shard_state_dict in total_sharded_dict.items(): From 8b7a73f20bf13de6ff224d55eca2c120ddb28881 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 31 Jul 2025 21:14:29 -0700 Subject: [PATCH 265/342] fix sampling params --- .../models/openai_moe/convert_openai_weights_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py index 401201d47fe8..f1b1d0ba709f 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py @@ -271,8 +271,8 @@ def write_model( do_sample=True, eos_token_id=[200002, 199999], # <|return|>, <|endoftext|> pad_token_id=199999, # <|endoftext|> - temperature=0.6, - top_p=0.9, + temperature=1.0, + top_p=1.0, ) generation_config.save_pretrained(model_path) From 08b031b7706d7b9f5c654c4c73d4901485844cf0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 1 Aug 2025 08:48:49 +0200 Subject: [PATCH 266/342] rename to final name --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/mxfp4.py | 8 +- src/transformers/modeling_utils.py | 4 +- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- .../models/auto/tokenization_auto.py | 2 +- .../{openai_moe => gpt_oss}/__init__.py | 4 +- .../configuration_gpt_oss.py} | 6 +- .../convert_gpt_oss_weights_to_hf.py} | 18 ++-- .../modeling_gpt_oss.py} | 82 +++++++++---------- .../modular_gpt_oss.py} | 58 ++++++------- .../quantizers/quantizer_mxfp4.py | 24 +++--- .../{openai_moe => gpt_oss}/__init__.py | 0 .../test_modeling_gpt_oss.py} | 68 +++++++-------- tests/quantization/mxfp4/test_mxfp4.py | 14 ++-- 16 files changed, 151 insertions(+), 151 deletions(-) rename src/transformers/models/{openai_moe => gpt_oss}/__init__.py (91%) rename src/transformers/models/{openai_moe/configuration_openai_moe.py => gpt_oss/configuration_gpt_oss.py} (97%) rename src/transformers/models/{openai_moe/convert_openai_weights_to_hf.py => gpt_oss/convert_gpt_oss_weights_to_hf.py} (97%) rename src/transformers/models/{openai_moe/modeling_openai_moe.py => gpt_oss/modeling_gpt_oss.py} (92%) rename src/transformers/models/{openai_moe/modular_openai_moe.py => gpt_oss/modular_gpt_oss.py} (92%) rename tests/models/{openai_moe => gpt_oss}/__init__.py (100%) rename tests/models/{openai_moe/test_modeling_openai_moe.py => gpt_oss/test_modeling_gpt_oss.py} (74%) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 01d6a7eaf4a6..683eb6c9e73a 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -121,7 +121,7 @@ ], "mxfp4": [ "replace_with_mxfp4_linear", - "Mxfp4OpenAIMoeExperts", + "Mxfp4GptOssExperts", "quantize_to_mxfp4", "convert_moe_packed_tensors", "dequantize", @@ -264,7 +264,7 @@ run_hp_search_wandb, ) from .mxfp4 import ( - Mxfp4OpenAIMoeExperts, + Mxfp4GptOssExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4, diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 8cf25e57a95f..01c384a63855 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -133,7 +133,7 @@ def convert_moe_packed_tensors( return out -class Mxfp4OpenAIMoeExperts(nn.Module): +class Mxfp4GptOssExperts(nn.Module): def __init__(self, config): super().__init__() @@ -433,11 +433,11 @@ def _replace_with_mxfp4_linear( if not should_convert_module(current_key_name, modules_to_not_convert): current_key_name.pop(-1) continue - if module.__class__.__name__ == "OpenAIMoeExperts" and not quantization_config.dequantize: + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: with init_empty_weights(): - model._modules[name] = Mxfp4OpenAIMoeExperts(config) + model._modules[name] = Mxfp4GptOssExperts(config) has_been_replaced = True - if module.__class__.__name__ == "OpenAIMoeMLP" and not quantization_config.dequantize: + if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize: from types import MethodType module.forward = MethodType(mlp_forward, module) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5b47b27d3250..3cd78cd0c7cd 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -890,8 +890,8 @@ def _load_state_dict_into_meta_model( param_name = hf_quantizer.update_param_name(param_name) module, param_type = get_module_from_name(model, param_name) value = getattr(module, param_type) - # special case for OpenAIMoeForCausalLM, we wait for the param to be leave the meta device before casting it to cpu - if model.__class__.__name__ == "OpenAIMoeForCausalLM" and value.device.type == "meta": + # special case for GptOssForCausalLM, we wait for the param to be leave the meta device before casting it to cpu + if model.__class__.__name__ == "GptOssForCausalLM" and value.device.type == "meta": continue param_to = "cpu" if is_fsdp_enabled() and not is_local_dist_rank_0(): diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f68ea78c29d2..275b0aea6dce 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -234,7 +234,7 @@ from .omdet_turbo import * from .oneformer import * from .openai import * - from .openai_moe import * + from .gpt_oss import * from .opt import * from .owlv2 import * from .owlvit import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 823f3f4d2b99..ee69412a14ec 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -272,7 +272,7 @@ ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), ("openai-gpt", "OpenAIGPTConfig"), - ("openai_moe", "OpenAIMoeConfig"), + ("gpt_oss", "GptOssConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), @@ -688,7 +688,7 @@ ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), ("openai-gpt", "OpenAI GPT"), - ("openai_moe", "OpenAIMoe"), + ("gpt_oss", "GptOss"), ("opt", "OPT"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fe4dbf253b41..682b1822414e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -260,7 +260,7 @@ ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), ("openai-gpt", "OpenAIGPTModel"), - ("openai_moe", "OpenAIMoeModel"), + ("gpt_oss", "GptOssModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), ("owlvit", "OwlViTModel"), @@ -664,7 +664,7 @@ ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), - ("openai_moe", "OpenAIMoeForCausalLM"), + ("gpt_oss", "GptOssForCausalLM"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), ("persimmon", "PersimmonForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 5eb13a59c53d..cb3d18f81cd9 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -483,7 +483,7 @@ "openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None), ), - ("openai_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/openai_moe/__init__.py b/src/transformers/models/gpt_oss/__init__.py similarity index 91% rename from src/transformers/models/openai_moe/__init__.py rename to src/transformers/models/gpt_oss/__init__.py index b35ca7cd1dc6..19e12e75ef8f 100644 --- a/src/transformers/models/openai_moe/__init__.py +++ b/src/transformers/models/gpt_oss/__init__.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: - from .configuration_openai_moe import * - from .modeling_openai_moe import * + from .configuration_gpt_oss import * + from .modeling_gpt_oss import * else: import sys diff --git a/src/transformers/models/openai_moe/configuration_openai_moe.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py similarity index 97% rename from src/transformers/models/openai_moe/configuration_openai_moe.py rename to src/transformers/models/gpt_oss/configuration_gpt_oss.py index 76b4e4fca651..9a8668e5798d 100644 --- a/src/transformers/models/openai_moe/configuration_openai_moe.py +++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py @@ -18,14 +18,14 @@ from ...modeling_rope_utils import rope_config_validation -class OpenAIMoeConfig(PretrainedConfig): +class GptOssConfig(PretrainedConfig): r""" This will yield a configuration to that of the BERT [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. """ - model_type = "openai_moe" + model_type = "gpt_oss" base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), @@ -116,4 +116,4 @@ def __init__( ) -__all__ = ["OpenAIMoeConfig"] +__all__ = ["GptOssConfig"] diff --git a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py similarity index 97% rename from src/transformers/models/openai_moe/convert_openai_weights_to_hf.py rename to src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 6ebcd20585d5..e675d43e91ef 100644 --- a/src/transformers/models/openai_moe/convert_openai_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -26,8 +26,8 @@ from transformers import ( GenerationConfig, - OpenAIMoeConfig, - OpenAIMoeForCausalLM, + GptOssConfig, + GptOssForCausalLM, PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import TikTokenConverter @@ -162,7 +162,7 @@ def write_model( "original_max_position_embeddings": 4096, } - config = OpenAIMoeConfig( + config = GptOssConfig( num_local_experts=num_local_experts, rope_scaling=rope_scaling, eos_token_id=eos_token_id, **original_config ) @@ -230,9 +230,9 @@ def write_model( gc.collect() if not mxfp4: - print("Loading the checkpoint in a OpenAIMoe model for unpacked format") + print("Loading the checkpoint in a GptOss model for unpacked format") with torch.device("meta"): - model = OpenAIMoeForCausalLM(config) + model = GptOssForCausalLM(config) model.load_state_dict(state_dict, strict=True, assign=True) print("Checkpoint loaded successfully.") del config._name_or_path @@ -253,14 +253,14 @@ def write_model( ], } # required as we don't save the model with save_pretrained - config.architectures = ["OpenAIMoeForCausalLM"] + config.architectures = ["GptOssForCausalLM"] config.save_pretrained(model_path) save_sharded_model(state_dict, model_path) del state_dict gc.collect() print("Reloading the model to check if it's saved correctly.") - OpenAIMoeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + GptOssForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") print("Model reloaded successfully.") # generation config @@ -340,7 +340,7 @@ def bytes_to_unicode(): return dict(zip(bs, cs)) -class OpenAIMoeConverter(TikTokenConverter): +class GptOssConverter(TikTokenConverter): def extract_vocab_merges_from_model(self, tiktoken_url: str): tokenizer = tiktoken.get_encoding(tiktoken_url) self.pattern = tokenizer._pat_str @@ -486,7 +486,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- endif -%} """ - converter = OpenAIMoeConverter( + converter = GptOssConverter( vocab_file=tokenizer_path, model_max_length=None, chat_template=chat_template if instruct else None, diff --git a/src/transformers/models/openai_moe/modeling_openai_moe.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py similarity index 92% rename from src/transformers/models/openai_moe/modeling_openai_moe.py rename to src/transformers/models/gpt_oss/modeling_gpt_oss.py index 48e6e9f260e1..4259543f8f84 100644 --- a/src/transformers/models/openai_moe/modeling_openai_moe.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/openai_moe/modular_openai_moe.py. +# This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the -# modular_openai_moe.py file directly. One of our CI enforces this. +# modular_gpt_oss.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The HuggingFace Team. All rights reserved. @@ -35,14 +35,14 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import OutputRecorder, check_model_inputs -from .configuration_openai_moe import OpenAIMoeConfig +from .configuration_gpt_oss import GptOssConfig @use_kernel_forward_from_hub("RMSNorm") -class OpenAIMoeRMSNorm(nn.Module): +class GptOssRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - OpenAIMoeRMSNorm is equivalent to T5LayerNorm + GptOssRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -59,7 +59,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class OpenAIMoeExperts(nn.Module): +class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size @@ -128,7 +128,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenAIMoeTopKRouter(nn.Module): +class GptOssTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -147,11 +147,11 @@ def forward(self, hidden_states): @use_kernel_forward_from_hub("MegaBlocksMoeMLP") -class OpenAIMoeMLP(nn.Module): +class GptOssMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = OpenAIMoeTopKRouter(config) - self.experts = OpenAIMoeExperts(config) + self.router = GptOssTopKRouter(config) + self.experts = GptOssExperts(config) def forward(self, hidden_states): router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) @@ -159,8 +159,8 @@ def forward(self, hidden_states): return routed_out, router_scores -class OpenAIMoeRotaryEmbedding(nn.Module): - def __init__(self, config: OpenAIMoeConfig, device=None): +class GptOssRotaryEmbedding(nn.Module): + def __init__(self, config: GptOssConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): @@ -253,10 +253,10 @@ def eager_attention_forward( return attn_output, attn_weights -class OpenAIMoeAttention(nn.Module): +class GptOssAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: OpenAIMoeConfig, layer_idx: int): + def __init__(self, config: GptOssConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx @@ -325,14 +325,14 @@ def forward( return attn_output, attn_weights -class OpenAIMoeDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: OpenAIMoeConfig, layer_idx: int): +class GptOssDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OpenAIMoeAttention(config=config, layer_idx=layer_idx) - self.mlp = OpenAIMoeMLP(config) - self.input_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] def forward( @@ -370,11 +370,11 @@ def forward( @auto_docstring -class OpenAIMoePreTrainedModel(PreTrainedModel): - config: OpenAIMoeConfig +class GptOssPreTrainedModel(PreTrainedModel): + config: GptOssConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["OpenAIMoeDecoderLayer"] + _no_split_modules = ["GptOssDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = False @@ -383,9 +383,9 @@ class OpenAIMoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=0), - "hidden_states": OpenAIMoeDecoderLayer, - "attentions": OpenAIMoeAttention, + "router_logits": OutputRecorder(GptOssTopKRouter, index=0), + "hidden_states": GptOssDecoderLayer, + "attentions": GptOssAttention, } _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _supports_flex_attention = False @@ -400,32 +400,32 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, OpenAIMoeRMSNorm): + elif isinstance(module, GptOssRMSNorm): module.weight.data.fill_(1.0) - elif isinstance(module, OpenAIMoeExperts): + elif isinstance(module, GptOssExperts): module.gate_up_proj.data.normal_(mean=0.0, std=std) module.gate_up_proj_bias.data.zero_() module.down_proj.data.normal_(mean=0.0, std=std) module.down_proj_bias.data.zero_() - elif isinstance(module, OpenAIMoeAttention): + elif isinstance(module, GptOssAttention): module.sinks.data.normal_(mean=0.0, std=std) @auto_docstring -class OpenAIMoeModel(OpenAIMoePreTrainedModel): - _no_split_modules = ["OpenAIMoeDecoderLayer"] +class GptOssModel(GptOssPreTrainedModel): + _no_split_modules = ["GptOssDecoderLayer"] - def __init__(self, config: OpenAIMoeConfig): + def __init__(self, config: GptOssConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [OpenAIMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [GptOssDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = OpenAIMoeRotaryEmbedding(config=config) + self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GptOssRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -582,14 +582,14 @@ def load_balancing_loss_func( @auto_docstring -class OpenAIMoeForCausalLM(OpenAIMoePreTrainedModel, GenerationMixin): +class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) - self.model = OpenAIMoeModel(config) + self.model = GptOssModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef @@ -630,10 +630,10 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, OpenAIMoeForCausalLM + >>> from transformers import AutoTokenizer, GptOssForCausalLM - >>> model = OpenAIMoeForCausalLM.from_pretrained("mistralai/OpenAIMoe-8x7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/OpenAIMoe-8x7B-v0.1") + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -692,4 +692,4 @@ def forward( ) -__all__ = ["OpenAIMoeForCausalLM", "OpenAIMoeModel", "OpenAIMoePreTrainedModel"] +__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"] diff --git a/src/transformers/models/openai_moe/modular_openai_moe.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py similarity index 92% rename from src/transformers/models/openai_moe/modular_openai_moe.py rename to src/transformers/models/gpt_oss/modular_gpt_oss.py index 1c10ee01d791..b2c973ff49ba 100644 --- a/src/transformers/models/openai_moe/modular_openai_moe.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -41,13 +41,13 @@ ) from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel from ..qwen2.modeling_qwen2 import Qwen2Attention -from .configuration_openai_moe import OpenAIMoeConfig +from .configuration_gpt_oss import GptOssConfig logger = logging.get_logger(__name__) -class OpenAIMoeRMSNorm(LlamaRMSNorm): +class GptOssRMSNorm(LlamaRMSNorm): def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -56,7 +56,7 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -class OpenAIMoeExperts(nn.Module): +class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size @@ -125,7 +125,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig return next_states -class OpenAIMoeTopKRouter(nn.Module): +class GptOssTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok @@ -144,11 +144,11 @@ def forward(self, hidden_states): @use_kernel_forward_from_hub("MegaBlocksMoeMLP") -class OpenAIMoeMLP(nn.Module): +class GptOssMLP(nn.Module): def __init__(self, config): super().__init__() - self.router = OpenAIMoeTopKRouter(config) - self.experts = OpenAIMoeExperts(config) + self.router = GptOssTopKRouter(config) + self.experts = GptOssExperts(config) def forward(self, hidden_states): router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) @@ -156,7 +156,7 @@ def forward(self, hidden_states): return routed_out, router_scores -class OpenAIMoeRotaryEmbedding(LlamaRotaryEmbedding): +class GptOssRotaryEmbedding(LlamaRotaryEmbedding): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): @@ -221,8 +221,8 @@ def eager_attention_forward( return attn_output, attn_weights -class OpenAIMoeAttention(Qwen2Attention): - def __init__(self, config: OpenAIMoeConfig, layer_idx: int): +class GptOssAttention(Qwen2Attention): + def __init__(self, config: GptOssConfig, layer_idx: int): super().__init__(config, layer_idx) self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -283,14 +283,14 @@ def forward( return attn_output, attn_weights -class OpenAIMoeDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: OpenAIMoeConfig, layer_idx: int): +class GptOssDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): super().__init__(config, layer_idx) self.hidden_size = config.hidden_size - self.self_attn = OpenAIMoeAttention(config=config, layer_idx=layer_idx) - self.mlp = OpenAIMoeMLP(config) - self.input_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = OpenAIMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] def forward( @@ -327,14 +327,14 @@ def forward( return hidden_states -class OpenAIMoePreTrainedModel(LlamaPreTrainedModel): +class GptOssPreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _supports_sdpa = False _supports_flex_attention = False _can_record_outputs = { - "router_logits": OutputRecorder(OpenAIMoeTopKRouter, index=0), - "hidden_states": OpenAIMoeDecoderLayer, - "attentions": OpenAIMoeAttention, + "router_logits": OutputRecorder(GptOssTopKRouter, index=0), + "hidden_states": GptOssDecoderLayer, + "attentions": GptOssAttention, } def _init_weights(self, module): @@ -347,19 +347,19 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, OpenAIMoeRMSNorm): + elif isinstance(module, GptOssRMSNorm): module.weight.data.fill_(1.0) - elif isinstance(module, OpenAIMoeExperts): + elif isinstance(module, GptOssExperts): module.gate_up_proj.data.normal_(mean=0.0, std=std) module.gate_up_proj_bias.data.zero_() module.down_proj.data.normal_(mean=0.0, std=std) module.down_proj_bias.data.zero_() - elif isinstance(module, OpenAIMoeAttention): + elif isinstance(module, GptOssAttention): module.sinks.data.normal_(mean=0.0, std=std) -class OpenAIMoeModel(MixtralModel): - _no_split_modules = ["OpenAIMoeDecoderLayer"] +class GptOssModel(MixtralModel): + _no_split_modules = ["GptOssDecoderLayer"] @check_model_inputs @auto_docstring @@ -426,12 +426,12 @@ def forward( ) -class OpenAIMoeForCausalLM(MixtralForCausalLM): +class GptOssForCausalLM(MixtralForCausalLM): pass __all__ = [ - "OpenAIMoeForCausalLM", - "OpenAIMoeModel", - "OpenAIMoePreTrainedModel", + "GptOssForCausalLM", + "GptOssModel", + "GptOssPreTrainedModel", ] diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 92d4ace1dca0..c64902e974d0 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -122,8 +122,8 @@ def check_quantized_param( state_dict: dict[str, Any], **kwargs, ): - from ..integrations import Mxfp4OpenAIMoeExperts - from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + from ..integrations import Mxfp4GptOssExperts + from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): @@ -131,8 +131,8 @@ def check_quantized_param( else: module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, Mxfp4OpenAIMoeExperts) or ( - isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize + if isinstance(module, Mxfp4GptOssExperts) or ( + isinstance(module, GptOssExperts) and self.quantization_config.dequantize ): if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]: return False @@ -152,13 +152,13 @@ def create_quantized_param( if is_triton_kernels_availalble(): from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig - from ..integrations import Mxfp4OpenAIMoeExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 - from ..models.openai_moe.modeling_openai_moe import OpenAIMoeExperts + from ..integrations import Mxfp4GptOssExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 + from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts if not self.pre_quantized: module, _ = get_module_from_name(model, param_name) with torch.cuda.device(target_device): - if isinstance(module, Mxfp4OpenAIMoeExperts): + if isinstance(module, Mxfp4GptOssExperts): if "gate_up_proj" in param_name: right_pad = module.gate_up_proj_right_pad bottom_pad = module.gate_up_proj_bottom_pad @@ -205,8 +205,8 @@ def create_quantized_param( "model": model, } - if isinstance(module, Mxfp4OpenAIMoeExperts) or ( - isinstance(module, OpenAIMoeExperts) and self.quantization_config.dequantize + if isinstance(module, Mxfp4GptOssExperts) or ( + isinstance(module, GptOssExperts) and self.quantization_config.dequantize ): if self.quantization_config.dequantize: # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears @@ -266,11 +266,11 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: - from ..integrations import Mxfp4OpenAIMoeExperts + from ..integrations import Mxfp4GptOssExperts not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, Mxfp4OpenAIMoeExperts): + if isinstance(module, Mxfp4GptOssExperts): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") @@ -281,7 +281,7 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li return [k for k in missing_keys if k not in not_missing_keys] def update_tp_plan(self, config): - if "OpenAIMoeConfig" in config.__class__.__name__: + if "GptOssConfig" in config.__class__.__name__: if getattr(config, "base_model_tp_plan", None) is not None: config.base_model_tp_plan.update( { diff --git a/tests/models/openai_moe/__init__.py b/tests/models/gpt_oss/__init__.py similarity index 100% rename from tests/models/openai_moe/__init__.py rename to tests/models/gpt_oss/__init__.py diff --git a/tests/models/openai_moe/test_modeling_openai_moe.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py similarity index 74% rename from tests/models/openai_moe/test_modeling_openai_moe.py rename to tests/models/gpt_oss/test_modeling_gpt_oss.py index d91d39a218ea..df35fc6114b0 100644 --- a/tests/models/openai_moe/test_modeling_openai_moe.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch OpenAIMoe model.""" +"""Testing suite for the PyTorch GptOss model.""" import unittest @@ -19,7 +19,7 @@ from parameterized import parameterized from tests.tensor_parallel.test_tensor_parallel import TensorParallelTestBase -from transformers import AutoModelForCausalLM, AutoTokenizer, OpenAIMoeConfig, is_torch_available +from transformers import AutoModelForCausalLM, AutoTokenizer, GptOssConfig, is_torch_available from transformers.testing_utils import ( cleanup, require_read_token, @@ -38,21 +38,21 @@ import torch from transformers import ( - OpenAIMoeForCausalLM, - OpenAIMoeModel, + GptOssForCausalLM, + GptOssModel, ) -class OpenAIMoeModelTester(CausalLMModelTester): +class GptOssModelTester(CausalLMModelTester): if is_torch_available(): - config_class = OpenAIMoeConfig - base_model_class = OpenAIMoeModel - causal_lm_class = OpenAIMoeForCausalLM + config_class = GptOssConfig + base_model_class = GptOssModel + causal_lm_class = GptOssForCausalLM pipeline_model_mapping = ( { - "feature-extraction": OpenAIMoeModel, - "text-generation": OpenAIMoeForCausalLM, + "feature-extraction": GptOssModel, + "text-generation": GptOssForCausalLM, } if is_torch_available() else {} @@ -60,12 +60,12 @@ class OpenAIMoeModelTester(CausalLMModelTester): @require_torch -class OpenAIMoeModelTest(CausalLMModelTest, unittest.TestCase): - all_model_classes = (OpenAIMoeModel, OpenAIMoeForCausalLM) if is_torch_available() else () +class GptOssModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (GptOssModel, GptOssForCausalLM) if is_torch_available() else () pipeline_model_mapping = ( { - "feature-extraction": OpenAIMoeModel, - "text-generation": OpenAIMoeForCausalLM, + "feature-extraction": GptOssModel, + "text-generation": GptOssForCausalLM, } if is_torch_available() else {} @@ -75,68 +75,68 @@ class OpenAIMoeModelTest(CausalLMModelTest, unittest.TestCase): test_pruning = False _is_stateful = True model_split_percents = [0.5, 0.6] - model_tester_class = OpenAIMoeModelTester + model_tester_class = GptOssModelTester def setUp(self): - self.model_tester = OpenAIMoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=OpenAIMoeConfig, hidden_size=37) + self.model_tester = GptOssModelTester(self) + self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37) @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): pass - @unittest.skip("OpenAIMoe's forcefully disables sdpa due to Sink") + @unittest.skip("GptOss's forcefully disables sdpa due to Sink") def test_sdpa_can_dispatch_non_composite_models(self): pass - @unittest.skip("OpenAIMoe's eager attn/sdpa attn outputs are expected to be different") + @unittest.skip("GptOss's eager attn/sdpa attn outputs are expected to be different") def test_eager_matches_sdpa_generate(self): pass @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate - @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass - @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass @pytest.mark.generate - @unittest.skip("OpenAIMoe has HybridCache which is not compatible with assisted decoding") + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass - @unittest.skip("OpenAIMoe has HybridCache which is not compatible with dola decoding") + @unittest.skip("GptOss has HybridCache which is not compatible with dola decoding") def test_dola_decoding_sample(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support continue from past kv") + @unittest.skip("GptOss has HybridCache and doesn't support continue from past kv") def test_generate_continue_from_past_key_values(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation") def test_contrastive_generate(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation") def test_contrastive_generate_dict_outputs_use_cache(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support contrastive generation") + @unittest.skip("GptOss has HybridCache and doesn't support contrastive generation") def test_contrastive_generate_low_memory(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_with_static_cache(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("OpenAIMoe has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_continue_from_inputs_embeds(self): pass @@ -147,18 +147,18 @@ def test_generate_continue_from_inputs_embeds(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip("OpenAIMoe has HybridCache which auto-compiles. Compile and FA2 don't work together.") + @unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.") def test_eager_matches_fa2_generate(self): pass - @unittest.skip("OpenAIMoe eager/FA2 attention outputs are expected to be different") + @unittest.skip("GptOss eager/FA2 attention outputs are expected to be different") def test_flash_attn_2_equivalence(self): pass @slow @require_torch_accelerator -class OpenAIMoeIntegrationTest(unittest.TestCase): +class GptOssIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] def setUp(self): @@ -239,7 +239,7 @@ def test_model_120b_bf16_use_kernels(self): @slow @require_torch_multi_accelerator -class OpenAIMoeTPTest(TensorParallelTestBase): +class GptOssTPTest(TensorParallelTestBase): def test_model_training(self): self.run_tensor_parallel_test( model_id="/fsx/vb/new-oai/20b-converted-quantized", mode="training", expected_output="you with something?" diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index 70f451b25b9f..e51d60204c10 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -16,7 +16,7 @@ import unittest from unittest.mock import patch -from transformers import AutoTokenizer, Mxfp4Config, OpenAIMoeForCausalLM +from transformers import AutoTokenizer, Mxfp4Config, GptOssForCausalLM from transformers.testing_utils import ( require_torch, require_torch_gpu, @@ -353,7 +353,7 @@ def check_inference_correctness_quantized(self, model, tokenizer): self.assertIn(generated_text, self.EXPECTED_OUTPUTS) - def test_openai_moe_model_loading_quantized_with_device_map(self): + def test_gpt_oss_model_loading_quantized_with_device_map(self): """Test loading OpenAI MoE model with mxfp4 quantization and device_map""" quantization_config = Mxfp4Config(dequantize=False) @@ -361,7 +361,7 @@ def test_openai_moe_model_loading_quantized_with_device_map(self): # Test that config is properly set up self.assertFalse(quantization_config.dequantize) - model = OpenAIMoeForCausalLM.from_pretrained( + model = GptOssForCausalLM.from_pretrained( self.model_name_packed, quantization_config=quantization_config, torch_dtype=torch.bfloat16, @@ -370,7 +370,7 @@ def test_openai_moe_model_loading_quantized_with_device_map(self): tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed) self.check_inference_correctness_quantized(model, tokenizer) - def test_openai_moe_model_loading_dequantized_with_device_map(self): + def test_gpt_oss_model_loading_dequantized_with_device_map(self): """Test loading OpenAI MoE model with mxfp4 dequantization and device_map""" quantization_config = Mxfp4Config(dequantize=True) @@ -378,7 +378,7 @@ def test_openai_moe_model_loading_dequantized_with_device_map(self): # Test that config is properly set up self.assertTrue(quantization_config.dequantize) - model = OpenAIMoeForCausalLM.from_pretrained( + model = GptOssForCausalLM.from_pretrained( self.model_name_packed, quantization_config=quantization_config, torch_dtype=torch.bfloat16, @@ -404,12 +404,12 @@ def test_memory_footprint_comparison(self): # Expected: quantized < dequantized < unquantized memory usage quantization_config = Mxfp4Config(dequantize=True) - quantized_model = OpenAIMoeForCausalLM.from_pretrained( + quantized_model = GptOssForCausalLM.from_pretrained( self.model_name_packed, torch_dtype=torch.bfloat16, device_map="auto", ) - dequantized_model = OpenAIMoeForCausalLM.from_pretrained( + dequantized_model = GptOssForCausalLM.from_pretrained( self.model_name_packed, torch_dtype=torch.bfloat16, device_map="auto", From 0d1a2da474bc6503000730d839f985645015b2f9 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Fri, 1 Aug 2025 13:36:53 +0200 Subject: [PATCH 267/342] upate pre-final version of template --- .../gpt_oss/convert_gpt_oss_weights_to_hf.py | 347 +++++++++++++++--- 1 file changed, 300 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index f00b80aedee1..23f046b449c6 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -418,73 +418,326 @@ def __init__( def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): # Updated Harmony chat template - chat_template = """{# Harmony chat template -------------------------------------------------- - This template mirrors the message rendering logic implemented in - `harmony/src/encoding.rs`. It can be consumed by Hugging Face - Transformers (``chat_template`` field) so that *text → tokens* - conversion of chat conversations happens fully on the Python side - without relying on the Rust renderer. - - Supported *message* keys (per ``chat::Message``): - - role (user│assistant│system│developer│tool) - - name (optional author name) - - recipient (optional recipient – omitted or "all" → broadcast) - - channel (optional meta channel) - - content_type (optional content-type qualifier) - - content (string – the actual message payload) - - The template renders each historical message *fully* (incl. the - trailing <|end|>/<|return|> sentinel) and – if ``add_generation_prompt`` - is True – appends a partial header for the **next** assistant turn - exactly like ``render_conversation_for_completion`` does on the Rust - side: ``<|start|>assistant``. -#} - -{%- macro harmony_header(m) -%} - <|start|>{% if m['role'] == 'tool' %}{{ m['name'] }}{% else %}{{ m['role'] }}{% if m.get('name') %}:{{ m['name'] }}{% endif %}{% endif %}{% if m.get('recipient') and m['recipient'] != 'all' %} to={{ m['recipient'] }}{% endif %}{% if m.get('channel') %}<|channel|>{{ m['channel'] }}{% endif %}{% if m.get('content_type') %} {{ m['content_type'] }}{% endif %}<|message|> + chat_template = """{#- + In addition to the normal inputs of `messages` and `tools`, this template also accepts the + following kwargs: + - "builtin_tools": A list, can contain "browser" and/or "python". + - "model_identity": A string that optionally describes the model identity". + - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". + #} + +{#- Tool Definition Rendering ============================================== #} +{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%} + {%- if param_spec.type == "array" -%} + {%- if param_spec['items'] -%} + {%- if param_spec['items']['type'] == "string" -%} + {{- "string[]" }} + {%- elif param_spec['items']['type'] == "number" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "integer" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "boolean" -%} + {{- "boolean[]" }} + {%- else -%} + {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%} + {%- if inner_type == "object | object" or inner_type|length > 50 -%} + {{- "any[]" }} + {%- else -%} + {{- inner_type + "[]" }} + {%- endif -%} + {%- endif -%} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- else -%} + {{- "any[]" }} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%} + {#- Handle array of types like ["object", "object"] from Union[dict, list] #} + {%- if param_spec.type | length > 1 -%} + {{- param_spec.type | join(" | ") }} + {%- else -%} + {{- param_spec.type[0] }} + {%- endif -%} + {%- elif param_spec.oneOf -%} + {#- Handle oneOf schemas - check for complex unions and fallback to any #} + {%- set has_object_variants = false -%} + {%- for variant in param_spec.oneOf -%} + {%- if variant.type == "object" -%} + {%- set has_object_variants = true -%} + {%- endif -%} + {%- endfor -%} + {%- if has_object_variants and param_spec.oneOf|length > 1 -%} + {{- "any" }} + {%- else -%} + {%- for variant in param_spec.oneOf -%} + {{- render_typescript_type(variant, required_params) -}} + {%- if variant.description %} + {{- "// " + variant.description }} + {%- endif -%} + {%- if variant.default is defined %} + {{ "// default: " + variant.default|tojson }} + {%- endif -%} + {%- if not loop.last %} + {{- " | " }} + {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- elif param_spec.type == "string" -%} + {%- if param_spec.enum -%} + {{- '"' + param_spec.enum|join('" | "') + '"' -}} + {%- else -%} + {{- "string" }} + {%- if param_spec.nullable %} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type == "number" -%} + {{- "number" }} + {%- elif param_spec.type == "integer" -%} + {{- "number" }} + {%- elif param_spec.type == "boolean" -%} + {{- "boolean" }} + + {%- elif param_spec.type == "object" -%} + {%- if param_spec.properties -%} + {{- "{\n" }} + {%- for prop_name, prop_spec in param_spec.properties.items() -%} + {{- prop_name -}} + {%- if prop_name not in (param_spec.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{ render_typescript_type(prop_spec, param_spec.required or []) }} + {%- if not loop.last -%} + {{-", " }} + {%- endif -%} + {%- endfor -%} + {{- "}" }} + {%- else -%} + {{- "object" }} + {%- endif -%} + {%- else -%} + {{- "any" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_tool_namespace(namespace_name, tools) -%} + {{- "## " + namespace_name + "\n\n" }} + {{- "namespace " + namespace_name + " {\n\n" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- "// " + tool.description + "\n" }} + {{- "type "+ tool.name + " = (" }} + {%- if tool.parameters and tool.parameters.properties -%} + {{- "_: " }} + {{- "{\n" }} + {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- if param_spec.description -%} + {{- "// " + param_spec.description + "\n" }} + {%- endif -%} + {{- param_name }} + {%- if param_name not in (tool.parameters.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {%- if param_spec.default is defined -%} + {%- if param_spec.oneOf %} + {{- "// default: " + param_spec.default }} + {%- else %} + {{- ", // default: " + param_spec.default|tojson }} + {%- endif -%} + {%- endif -%} + {%- if not loop.last %} + {{- ",\n" }} + {%- endif -%} + {%- endfor %} + {{- ",\n}) => any;\n" }} + {%- else -%} + {{- "\n}) => any;\n" }} + {%- endif -%} + {%- endfor %} + {{- "\n} // namespace " + namespace_name }} +{%- endmacro -%} + +{%- macro render_builtin_tools(browser_tool, python_tool) -%} + {%- if browser_tool %} + {{- "## browser\n\n" }} + {{- "// Tool for browsing.\n" }} + {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }} + {{- "// Cite information from the tool using the following format:\n" }} + {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }} + {{- "// Do not quote more than 10 words directly from the tool output.\n" }} + {{- "// sources=web (default: web)\n" }} + {{- "namespace browser {\n\n" }} + {{- "// Searches for information related to `query` and displays `topn` results.\n" }} + {{- "type search = (_: {\n" }} + {{- "query: string,\n" }} + {{- "topn?: number, // default: 10\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }} + {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }} + {{- "// If `cursor` is not provided, the most recent page is implied.\n" }} + {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }} + {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }} + {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }} + {{- "type open = (_: {\n" }} + {{- "id?: number | string, // default: -1\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "loc?: number, // default: -1\n" }} + {{- "num_lines?: number, // default: -1\n" }} + {{- "view_source?: boolean, // default: false\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }} + {{- "type find = (_: {\n" }} + {{- "pattern: string,\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "}) => any;\n\n" }} + {{- "} // namespace browser\n\n" }} + {%- endif -%} + + {%- if python_tool %} + {{- "## python\n\n" }} + {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }} + {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }} + {%- endif -%} +{%- endmacro -%} + +{#- System Message Construction ============================================ #} +{%- macro build_system_message() -%} + {%- if model_identity is not defined %} + {{- "You are ChatGPT, a large language model trained by OpenAI.\n" -}} + {%- else %} + {{- model_identity }} + {%- endif %} + {{- "Knowledge cutoff: 2024-06\n" }} + {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }} + {%- if reasoning_effort is not defined %} + {%- set reasoning_effort = "medium" %} + {%- endif %} + {{- "reasoning: " + reasoning_effort + "\n\n" }} + {%- if builtin_tools %} + {{- "# Tools\n\n" }} + {%- set available_builtin_tools = namespace(browser=false, python=false) %} + {%- for tool in builtin_tools %} + {%- if tool == "browser" %} + {%- set available_builtin_tools.browser = true %} + {%- elif tool == "python" %} + {%- set available_builtin_tools.python = true %} + {%- endif %} + {%- endfor %} + {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }} + {%- endif -%} + {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message.\n" }} + {{- "Calls to these tools must go to the commentary channel: 'functions'." }} {%- endmacro -%} -{# Add CoT dropping logic -------------------------------------------- #} -{%- set last_final_idx = None -%} +{#- CoT Dropping Logic ================================================== #} +{%- set cot_final_indices = [] -%} {%- for idx in range(messages|length) -%} {%- set m = messages[idx] -%} - {%- if m['role'] == 'assistant' and m.get('channel') == 'final' -%} - {%- set last_final_idx = idx -%} + {%- if m.role == 'assistant' and m.get('channel', '') == 'final' -%} + {%- if cot_final_indices.append(idx) -%}{%- endif -%} {%- endif -%} {%- endfor -%} -{%- set last_user_idx = None -%} -{%- if last_final_idx is not none -%} - {%- for idx in range(last_final_idx - 1, -1, -1) -%} - {%- if messages[idx]['role'] == 'user' -%} - {%- set last_user_idx = idx -%} - {%- break -%} +{%- set cot_last_final_idx = cot_final_indices[-1] if cot_final_indices else none -%} +{%- set cot_last_user_idx = none -%} +{%- if cot_last_final_idx is not none -%} + {%- for idx in range(cot_last_final_idx - 1, -1, -1) -%} + {%- if messages[idx].role == 'user' and cot_last_user_idx is none -%} + {%- set cot_last_user_idx = idx -%} {%- endif -%} {%- endfor -%} {%- endif -%} -{# --------------------------------------------------------------------- - Render complete history (with CoT dropping) -#} -{%- for idx in range(messages|length) -%} - {%- set message = messages[idx] -%} +{#- Main Template Logic ================================================= #} +{#- Set defaults #} +{%- set auto_drop = auto_drop_analysis if auto_drop_analysis is defined else true -%} + +{#- Render system message #} +{{- "<|start|>system<|message|>" }} +{{- build_system_message() }} +{{- "<|end|>" }} + +{#- Extract developer message #} +{%- if messages[0].role == "developer" or messages[0].role == "system" %} + {%- set developer_message = messages[0].content %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set developer_message = "" %} + {%- set loop_messages = messages %} +{%- endif %} + +{#- Render developer message #} +{%- if developer_message or tools %} + {{- "<|start|>developer<|message|>" }} + {%- if developer_message %} + {{- "# Instructions\n\n" }} + {{- developer_message }} + {%- endif %} + {%- if tools -%} + {{- "\n\n" }} + {{- "# Tools\n\n" }} + {{- render_tool_namespace("functions", tools) }} + {%- endif -%} + {{- "<|end|>" }} +{%- endif %} + +{#- Render messages #} +{%- set last_tool_call = namespace(name=none) %} +{%- for message in loop_messages -%} {%- set skip = false -%} - {%- if last_final_idx is not none and idx < last_final_idx and (last_user_idx is none or idx > last_user_idx) -%} - {%- if message['role'] == 'assistant' and message.get('channel') != 'final' -%} + + {# Apply CoT dropping logic #} + {%- if auto_drop and cot_last_final_idx is not none and loop.index0 < cot_last_final_idx -%} + {%- if message.role == 'assistant' and message.get('channel', '') != 'final' -%} + {%- if cot_last_user_idx is none or loop.index0 > cot_last_user_idx -%} + {%- set skip = true -%} + {%- endif -%} + {%- elif message.role == 'user' and message.get('channel', '') == 'analysis' -%} {%- set skip = true -%} {%- endif -%} {%- endif -%} + {%- if not skip -%} - {{- harmony_header(message) -}}{{ message['content'] }}{%- if message['role'] == 'assistant' and message.get('channel') == 'final' -%}<|return|>{%- else -%}<|end|>{%- endif -%} + {#- At this point only assistant/user/tool messages should remain #} + {%- if message.role == 'assistant' -%} + {%- if "tool_calls" in message %} + {# I'm assuming max 1 tool call per message here, which might be wrong #} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content }} + {{- "<|end|><|start|>assistant to=" }} + {{- "functions." + message.tool_calls[0].name + "<|channel|>commentary json<|message|>" }} + {{- message.tool_calls[0].arguments|tojson }} + {{- "<|end|>" }} + {%- set last_tool_call.name = message.tool_calls[0].name %} + {%- elif "thinking" in message %} + {#- CoT is dropped during all model inputs, so we never actually render it #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {%- else %} + {{- "<|start|>assistant<|message|>" + message.content + "<|end|>" }} + {%- endif %} + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none %} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif %} + {{- "<|start|>functions." + last_tool_call.name }} + {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} + {%- else -%} + {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} + {%- endif -%} {%- endif -%} {%- endfor -%} -{# --------------------------------------------------------------------- - Generation prompt for *next* assistant answer -#} +{#- Generation prompt #} {%- if add_generation_prompt -%} <|start|>assistant -{%- endif -%} -""" +{%- endif -%}""" converter = GptOssConverter( vocab_file=tokenizer_path, From 5f3de46c8cc239367c0047e21f91bfbd5e515c50 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 1 Aug 2025 13:41:30 +0200 Subject: [PATCH 268/342] Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 23f046b449c6..6e5e4a0b16e4 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -422,7 +422,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): In addition to the normal inputs of `messages` and `tools`, this template also accepts the following kwargs: - "builtin_tools": A list, can contain "browser" and/or "python". - - "model_identity": A string that optionally describes the model identity". + - "model_identity": A string that optionally describes the model identity. - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". #} From bddc8c2a3b3c0b862f0575aea67580b3d67723d5 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Fri, 1 Aug 2025 12:23:36 +0000 Subject: [PATCH 269/342] fix batched inference --- src/transformers/integrations/mxfp4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 01c384a63855..9a572ba2b11f 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -291,11 +291,12 @@ def mlp_forward(self, hidden_states): from triton_kernels.routing import routing routing = routing - + batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) return routed_out, router_logits From 06b35eb51ba269353b916ea86f660e55e3aed1db Mon Sep 17 00:00:00 2001 From: "joao@huggingface.co" Date: Fri, 1 Aug 2025 15:07:25 +0000 Subject: [PATCH 270/342] serve fixes --- src/transformers/commands/serving.py | 72 ++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 06a3bc6b9205..a5f87b2b5171 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -899,7 +899,16 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]: inputs = inputs.to(model.device) request_id = req.get("request_id", "req_0") - generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) + # Temporary hack for GPTOSS 1: don't filter special tokens + skip_special_tokens = True + if "gptoss" in model.config.architectures[0].lower(): + skip_special_tokens = False + + generation_streamer = TextIteratorStreamer( + processor, + skip_special_tokens=skip_special_tokens, + skip_prompt=True, + ) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None @@ -915,12 +924,21 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]: } def stream_chat_completion(streamer, _request_id): + # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output + # classes and piping the reasoning trace into a new field + filter_cot = False + cot_trace_end = None + if "gptoss" in model.config.architectures[0].lower(): + filter_cot = True + cot_trace_end = "<|channel|>final<|message|>" + # Thin wrapper to save the KV cache after generation def generate_with_cache(**kwargs): generate_output = model.generate(**kwargs) self.last_kv_cache = generate_output.past_key_values thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) + results = "" try: thread.start() @@ -931,6 +949,21 @@ def generate_with_cache(**kwargs): yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) for result in streamer: + + # Temporary hack for GPTOS 3: don't emit the final "<|return|>" + if "gptoss" in model.config.architectures[0].lower(): + if result.endswith("<|return|>"): + result = result[:-len("<|return|>")] + results += result + + # (related to temporary hack 2) + if filter_cot: + if cot_trace_end in results: # end of reasoning trace observed -> stop filtering + filter_cot = False + continue + else: + continue + # ====== TOOL CALL LOGIC ====== if tool_model_family is not None: # Start of a tool call: reset state variables, set `inside_tool_call` @@ -1035,7 +1068,16 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device) request_id = req.get("previous_response_id", "req_0") - generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) + # Temporary hack for GPTOSS 1: don't filter special tokens + skip_special_tokens = True + if "gptoss" in model.config.architectures[0].lower(): + skip_special_tokens = False + + generation_streamer = TextIteratorStreamer( + processor, + skip_special_tokens=skip_special_tokens, + skip_prompt=True, + ) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None @@ -1052,6 +1094,15 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: } def stream_response(streamer, _request_id): + + # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output + # classes and piping the reasoning trace into a new field + filter_cot = False + cot_trace_end = None + if "gptoss" in model.config.architectures[0].lower(): + filter_cot = True + cot_trace_end = "<|channel|>final<|message|>" + # Thin wrapper to save the KV cache after generation def generate_with_cache(**kwargs): generate_output = model.generate(**kwargs) @@ -1138,7 +1189,21 @@ def generate_with_cache(**kwargs): # Stream the actual generated text results = "" for result in streamer: + + # Temporary hack for GPTOS 3: don't emit the final "<|return|>" + if "gptoss" in model.config.architectures[0].lower(): + if result.endswith("<|return|>"): + result = result[:-len("<|return|>")] results += result + + # (related to temporary hack 2) + if filter_cot: + if cot_trace_end in results: # end of reasoning trace observed -> stop filtering + filter_cot = False + continue + else: + continue + response_output_text_delta = ResponseTextDeltaEvent( type="response.output_text.delta", item_id=f"msg_{request_id}", @@ -1417,9 +1482,10 @@ def _load_model_and_data_processor(self, model_id_and_revision: str): "attn_implementation": args.attn_implementation, "torch_dtype": torch_dtype, "device_map": "auto", - "quantization_config": quantization_config, "trust_remote_code": args.trust_remote_code, } + if quantization_config is not None: + model_kwargs["quantization_config"] = quantization_config config = AutoConfig.from_pretrained(model_id, **model_kwargs) architecture = getattr(transformers, config.architectures[0]) From 0de8f62737e3ac56ed55a380cec804e40196c44e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 1 Aug 2025 15:48:41 +0000 Subject: [PATCH 271/342] swizzle ! --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/mxfp4.py | 117 +++++++----------- src/transformers/models/__init__.py | 2 +- .../quantizers/quantizer_mxfp4.py | 4 +- tests/quantization/mxfp4/test_mxfp4.py | 2 +- 5 files changed, 48 insertions(+), 81 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 683eb6c9e73a..14fdae321e47 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -125,7 +125,7 @@ "quantize_to_mxfp4", "convert_moe_packed_tensors", "dequantize", - "dequantize_and_quantize", + "load_and_swizzle_mxfp4", ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], @@ -266,7 +266,7 @@ from .mxfp4 import ( Mxfp4GptOssExperts, dequantize, - dequantize_and_quantize, + load_and_swizzle_mxfp4, quantize_to_mxfp4, replace_with_mxfp4_linear, ) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 01c384a63855..e84f2183b4cf 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..modeling_utils import is_deepspeed_zero3_enabled, is_fsdp_enabled from ..utils import is_accelerate_available, is_torch_available, logging @@ -51,11 +50,15 @@ # Copied from GPT_OSS repo and vllm def quantize_to_mxfp4(w): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp + w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) + w, w_scale = swizzle_mxfp4(w, w_scale) + return w, w_scale + +def swizzle_mxfp4(w, w_scale): from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.tensor_details.layout import StridedLayout - w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) @@ -68,12 +71,10 @@ def quantize_to_mxfp4(w): # opt_flags.update_opt_flags_constraints(constraints) # # transpose the tensor so that the quantization axis is on dim1 - # TODO: there is still an issue with the scales on hopper # scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8) # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) - return w, w_scale # Copied from GPT_OSS repo @@ -121,7 +122,7 @@ def convert_moe_packed_tensors( sub[:, 1::2] = lut[idx_hi] torch.ldexp(sub, exp, out=sub) - del idx_lo, idx_hi, blk, exp + del idx_lo, idx_hi, blk, exp, sub out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) @@ -129,7 +130,7 @@ def convert_moe_packed_tensors( # Move back to CPU if needed # if need_to_move_back: # out = out.cpu() - del blocks, scales + del blocks, scales, lut return out @@ -140,47 +141,35 @@ def __init__(self, config): self.num_experts = config.num_local_experts self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size self.gate_up_proj_blocks = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.expert_dim, self.hidden_size // 32, 16), dtype=torch.uint8), + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.expert_dim, self.hidden_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False, ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.expert_dim, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False ) self.alpha = 1.702 self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - # TODO: To remove once we make sure that we don't need this - # smallest_even_divide_number = lambda x, n: (x // n + 1) * n if x % n != 0 else x - - self.gate_up_proj_right_pad = ( - 0 # smallest_even_divide_number(self.intermediate_size * 2, 256) - self.intermediate_size * 2 - ) - self.gate_up_proj_bottom_pad = 0 - self.down_proj_right_pad = 0 # smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - self.down_proj_bottom_pad = 0 # self.gate_up_proj_right_pad // 2 - self.hidden_size_pad = 0 # smallest_even_divide_number(self.hidden_size, 256) - self.hidden_size - def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.swiglu import swiglu_fn @@ -188,11 +177,6 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter with torch.cuda.device(hidden_states.device): act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2) - if self.hidden_size_pad is not None: - hidden_states = torch.nn.functional.pad( - hidden_states, (0, self.hidden_size_pad, 0, 0), mode="constant", value=0 - ) - intermediate_cache1 = matmul_ogs( hidden_states, self.gate_up_proj, @@ -241,13 +225,13 @@ def routing_torch_dist( n_gates_pad = n_tokens * n_expts_act - def topk(vals, k, expt_indx): + def topk(vals, k): tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] tk_indx = tk_indx.long() tk_val = torch.take_along_dim(vals, tk_indx, dim=1) return tk_val, tk_indx.int() - expt_scal, expt_indx = topk(logits, n_expts_act, None) + expt_scal, expt_indx = topk(logits, n_expts_act) expt_scal = torch.softmax(expt_scal, dim=-1) expt_indx, sort_indices = torch.sort(expt_indx, dim=1) expt_scal = torch.gather(expt_scal, 1, sort_indices) @@ -336,7 +320,6 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** scales_attr = f"{proj}_scales" if not hasattr(module, blocks_attr) and not hasattr(module, scales_attr): setattr(module, param_name.rsplit(".", 1)[1], param_value) - return else: setattr(module, param_name.rsplit(".", 1)[1], param_value) dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) @@ -347,10 +330,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** setattr(module, proj, torch.nn.Parameter(dequantized)) delattr(module, blocks_attr) delattr(module, scales_attr) - return - -def dequantize_and_quantize( +def load_and_swizzle_mxfp4( module, param_name, param_value, target_device, **kwargs ): from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig @@ -364,58 +345,44 @@ def dequantize_and_quantize( to_contiguous = kwargs.get("to_contiguous", None) rank = kwargs.get("rank", None) device_mesh = kwargs.get("device_mesh", None) - # Combine logic for gate_up_proj and down_proj + for proj in ["gate_up_proj", "down_proj"]: if proj in param_name: + if device_mesh is not None: + shard_and_distribute_module( + model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh + ) + else: + _load_parameter_into_model(model, param_name, param_value) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" - right_pad_attr = f"{proj}_right_pad" - bottom_pad_attr = f"{proj}_bottom_pad" - precision_config_attr = f"{proj}_precision_config" - - # Check if both blocks and scales are still on meta device blocks = getattr(module, blocks_attr) scales = getattr(module, scales_attr) - if blocks.device.type == "meta" and scales.device.type == "meta": - if device_mesh is not None: - shard_and_distribute_module( - model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh - ) + # Check if both blocks and scales both not on on meta device + if blocks.device.type != "meta" and scales.device.type != "meta": + # need it for ep + local_experts = getattr(module, blocks_attr).size(0) + if proj == "gate_up_proj": + blocks = module.gate_up_proj_blocks.view(local_experts, module.intermediate_size * 2, -1) else: - _load_parameter_into_model(model, param_name, param_value) - return - else: - # One of the params is already loaded, so load the other - if device_mesh is not None: - shard_and_distribute_module( - model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh - ) - else: - _load_parameter_into_model(model, param_name, param_value) + blocks = module.down_proj_blocks.view(local_experts, -1, module.intermediate_size // 2) + with torch.cuda.device(target_device): + triton_weight_tensor, weight_scale = swizzle_mxfp4(blocks.transpose(-2, -1), getattr(module, scales_attr).transpose(-2, -1)) - dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) - dequantized = dequantized.transpose(1, 2).contiguous().to(target_device) + # need to overwrite the shapes for the kernels + if proj == "gate_up_proj": + triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2]) + else: + triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size]) - right_pad = getattr(module, right_pad_attr) - bottom_pad = getattr(module, bottom_pad_attr) - - dequantized = torch.nn.functional.pad( - dequantized, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 - ) - original_device = target_device - # for fsdp and deepspeed since the model is load on cpu, we need to move the weight to gpu for quantization - if (is_fsdp_enabled() or is_deepspeed_zero3_enabled()) and target_device == "cpu": - dequantized = dequantized.cuda() - target_device = "cuda" - with torch.cuda.device(target_device): - triton_weight_tensor, weight_scale = quantize_to_mxfp4(dequantized) - triton_weight_tensor.storage.data = triton_weight_tensor.storage.data.to(original_device) - setattr(module, precision_config_attr, PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor setattr(module, proj, triton_weight_tensor) - setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) - return + setattr(module, f"{proj}_precision_config", PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) + # delete blocks and scales + delattr(module, scales_attr) + delattr(module, blocks_attr) + # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) def _replace_with_mxfp4_linear( model, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 275b0aea6dce..80144b40dbb1 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -138,6 +138,7 @@ from .gpt_neo import * from .gpt_neox import * from .gpt_neox_japanese import * + from .gpt_oss import * from .gpt_sw3 import * from .gptj import * from .granite import * @@ -234,7 +235,6 @@ from .omdet_turbo import * from .oneformer import * from .openai import * - from .gpt_oss import * from .opt import * from .owlv2 import * from .owlvit import * diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index c64902e974d0..5b5a87fa1863 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -152,7 +152,7 @@ def create_quantized_param( if is_triton_kernels_availalble(): from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig - from ..integrations import Mxfp4GptOssExperts, dequantize, dequantize_and_quantize, quantize_to_mxfp4 + from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4 from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts if not self.pre_quantized: @@ -214,7 +214,7 @@ def create_quantized_param( dq_param_name = param_name[: -len("_blocks")] dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs) else: - dequantize_and_quantize( + load_and_swizzle_mxfp4( module, param_name, param_value, diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index e51d60204c10..2194c2d3219e 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -16,7 +16,7 @@ import unittest from unittest.mock import patch -from transformers import AutoTokenizer, Mxfp4Config, GptOssForCausalLM +from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config from transformers.testing_utils import ( require_torch, require_torch_gpu, From aca1e72b7a68f2e190735efce137e9eadcd5bd64 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Fri, 1 Aug 2025 18:11:49 +0200 Subject: [PATCH 272/342] update final chat template by Matt. --- .../gpt_oss/convert_gpt_oss_weights_to_hf.py | 97 +++++++------------ 1 file changed, 35 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 6e5e4a0b16e4..f45b7e4d9e76 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -422,7 +422,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): In addition to the normal inputs of `messages` and `tools`, this template also accepts the following kwargs: - "builtin_tools": A list, can contain "browser" and/or "python". - - "model_identity": A string that optionally describes the model identity. + - "model_identity": A string that optionally describes the model identity". - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". #} @@ -536,9 +536,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- "_: " }} {{- "{\n" }} {%- for param_name, param_spec in tool.parameters.properties.items() %} - {%- if param_spec.description -%} - {{- "// " + param_spec.description + "\n" }} - {%- endif -%} + {{- "// " + param_spec.description + "\n" }} {{- param_name }} {%- if param_name not in (tool.parameters.required or []) -%} {{- "?" }} @@ -638,27 +636,8 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- "Calls to these tools must go to the commentary channel: 'functions'." }} {%- endmacro -%} -{#- CoT Dropping Logic ================================================== #} -{%- set cot_final_indices = [] -%} -{%- for idx in range(messages|length) -%} - {%- set m = messages[idx] -%} - {%- if m.role == 'assistant' and m.get('channel', '') == 'final' -%} - {%- if cot_final_indices.append(idx) -%}{%- endif -%} - {%- endif -%} -{%- endfor -%} -{%- set cot_last_final_idx = cot_final_indices[-1] if cot_final_indices else none -%} -{%- set cot_last_user_idx = none -%} -{%- if cot_last_final_idx is not none -%} - {%- for idx in range(cot_last_final_idx - 1, -1, -1) -%} - {%- if messages[idx].role == 'user' and cot_last_user_idx is none -%} - {%- set cot_last_user_idx = idx -%} - {%- endif -%} - {%- endfor -%} -{%- endif -%} - {#- Main Template Logic ================================================= #} {#- Set defaults #} -{%- set auto_drop = auto_drop_analysis if auto_drop_analysis is defined else true -%} {#- Render system message #} {{- "<|start|>system<|message|>" }} @@ -692,45 +671,39 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {#- Render messages #} {%- set last_tool_call = namespace(name=none) %} {%- for message in loop_messages -%} - {%- set skip = false -%} - - {# Apply CoT dropping logic #} - {%- if auto_drop and cot_last_final_idx is not none and loop.index0 < cot_last_final_idx -%} - {%- if message.role == 'assistant' and message.get('channel', '') != 'final' -%} - {%- if cot_last_user_idx is none or loop.index0 > cot_last_user_idx -%} - {%- set skip = true -%} - {%- endif -%} - {%- elif message.role == 'user' and message.get('channel', '') == 'analysis' -%} - {%- set skip = true -%} - {%- endif -%} - {%- endif -%} - - {%- if not skip -%} - {#- At this point only assistant/user/tool messages should remain #} - {%- if message.role == 'assistant' -%} - {%- if "tool_calls" in message %} - {# I'm assuming max 1 tool call per message here, which might be wrong #} - {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content }} - {{- "<|end|><|start|>assistant to=" }} - {{- "functions." + message.tool_calls[0].name + "<|channel|>commentary json<|message|>" }} - {{- message.tool_calls[0].arguments|tojson }} - {{- "<|end|>" }} - {%- set last_tool_call.name = message.tool_calls[0].name %} - {%- elif "thinking" in message %} - {#- CoT is dropped during all model inputs, so we never actually render it #} - {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} - {%- else %} - {{- "<|start|>assistant<|message|>" + message.content + "<|end|>" }} - {%- endif %} - {%- elif message.role == 'tool' -%} - {%- if last_tool_call.name is none %} - {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} - {%- endif %} - {{- "<|start|>functions." + last_tool_call.name }} - {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} - {%- else -%} - {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} - {%- endif -%} + {#- At this point only assistant/user/tool messages should remain #} + {%- if message.role == 'assistant' -%} + {%- if "tool_calls" in message %} + {#- We assume max 1 tool call per message, and so we infer the tool call name #} + {#- in "tool" messages from the most recent assistant tool call name #} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content }} + {{- "<|end|><|start|>assistant to=" }} + {{- "functions." + message.tool_calls[0].name + "<|channel|>commentary json<|message|>" }} + {{- message.tool_calls[0].arguments|tojson }} + {{- "<|end|>" }} + {%- set last_tool_call.name = message.tool_calls[0].name %} + {%- elif "thinking" in message and loop.last and not add_generation_prompt %} + {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} + {#- This is a situation that should only occur in training, never in inference. #} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {%- set last_tool_call.name = none %} + {%- elif "thinking" in message %} + {#- CoT is dropped during all previous turns, so we never render it for inference #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {%- set last_tool_call.name = none %} + {%- else %} + {{- "<|start|>assistant<|message|>" + message.content + "<|end|>" }} + {%- set last_tool_call.name = none %} + {%- endif %} + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none %} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif %} + {{- "<|start|>functions." + last_tool_call.name }} + {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} + {%- else -%} + {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} {%- endif -%} {%- endfor -%} From a8c3c4930c17859f01b7f408b99531307f041258 Mon Sep 17 00:00:00 2001 From: "joao@huggingface.co" Date: Fri, 1 Aug 2025 16:12:04 +0000 Subject: [PATCH 273/342] fix responses; pin oai --- setup.py | 2 +- src/transformers/commands/serving.py | 32 ++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index c728d6396e1e..c8470306c3c5 100644 --- a/setup.py +++ b/setup.py @@ -137,7 +137,7 @@ "onnxconverter-common", "onnxruntime-tools>=1.4.2", "onnxruntime>=1.4.0", - "openai", + "openai>=1.98.0", "opencv-python", "optimum-benchmark>=0.3.0", "optuna", diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index a5f87b2b5171..8bcc9df8e40b 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -949,11 +949,10 @@ def generate_with_cache(**kwargs): yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) for result in streamer: - # Temporary hack for GPTOS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): if result.endswith("<|return|>"): - result = result[:-len("<|return|>")] + result = result[: -len("<|return|>")] results += result # (related to temporary hack 2) @@ -1065,7 +1064,27 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: self.last_model = model_id_and_revision model, processor = self.load_model_and_processor(model_id_and_revision) - inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device) + if isinstance(req["input"], str): + inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] + inputs.append({"role": "user", "content": req["input"]}) + elif isinstance(req["input"], list): + if "instructions" in req: + if req["input"][0]["role"] != "system": + inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] + else: + inputs = req["input"] + inputs[0]["content"] = req["instructions"] + else: + inputs = req["input"] + elif isinstance(req["input"], dict): + inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] + inputs.append(req["input"]) + else: + raise ValueError("inputs should be a list, dict, or str") + + inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt") + inputs = inputs.to(model.device) + request_id = req.get("previous_response_id", "req_0") # Temporary hack for GPTOSS 1: don't filter special tokens @@ -1094,7 +1113,6 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: } def stream_response(streamer, _request_id): - # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output # classes and piping the reasoning trace into a new field filter_cot = False @@ -1189,17 +1207,17 @@ def generate_with_cache(**kwargs): # Stream the actual generated text results = "" for result in streamer: - # Temporary hack for GPTOS 3: don't emit the final "<|return|>" if "gptoss" in model.config.architectures[0].lower(): if result.endswith("<|return|>"): - result = result[:-len("<|return|>")] + result = result[: -len("<|return|>")] results += result # (related to temporary hack 2) if filter_cot: if cot_trace_end in results: # end of reasoning trace observed -> stop filtering filter_cot = False + results = "" # reset the results -> results will now track the final response continue else: continue @@ -1211,6 +1229,7 @@ def generate_with_cache(**kwargs): output_index=output_index, content_index=content_index, delta=result, + logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs ) sequence_number += 1 yield self.build_response_event(response_output_text_delta) @@ -1223,6 +1242,7 @@ def generate_with_cache(**kwargs): output_index=output_index, content_index=0, text=results, + logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs ) sequence_number += 1 yield self.build_response_event(response_output_text_done) From 33636c91490725df4849c151297f5afd47192e97 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 1 Aug 2025 16:17:15 +0000 Subject: [PATCH 274/342] sinplify --- src/transformers/integrations/mxfp4.py | 12 ++++-------- src/transformers/quantizers/quantizer_mxfp4.py | 3 +++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 9b4e57bed201..46ac4d3f106a 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -61,7 +61,6 @@ def swizzle_mxfp4(w, w_scale): value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) - # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -319,10 +318,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** ) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" - if not hasattr(module, blocks_attr) and not hasattr(module, scales_attr): - setattr(module, param_name.rsplit(".", 1)[1], param_value) - else: - setattr(module, param_name.rsplit(".", 1)[1], param_value) + setattr(module, param_name.rsplit(".", 1)[1], param_value) + if hasattr(module, blocks_attr) and hasattr(module, scales_attr): dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) dequantized = dequantized.transpose(1, 2).contiguous().to(target_device) # TODO: this is perhaps necessary since if target_device is cpu, and the param was on gpu @@ -338,7 +335,6 @@ def load_and_swizzle_mxfp4( from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig from ..integrations.tensor_parallel import shard_and_distribute_module - from ..modeling_utils import _load_parameter_into_model model = kwargs.get("model", None) empty_param = kwargs.get("empty_param", None) @@ -354,7 +350,7 @@ def load_and_swizzle_mxfp4( model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh ) else: - _load_parameter_into_model(model, param_name, param_value) + setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" blocks = getattr(module, blocks_attr) @@ -384,7 +380,7 @@ def load_and_swizzle_mxfp4( delattr(module, scales_attr) delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) - + del blocks def _replace_with_mxfp4_linear( model, modules_to_not_convert=None, diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 5b5a87fa1863..679199a124f7 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -226,6 +226,9 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs # we are not really dequantizing, we are just removing everthing related to quantization here if self.quantization_config.dequantize: self.remove_quantization_config(model) + # clean cache due to triton ops + if not torch.cuda.is_available(): + torch.cuda.empty_cache() def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]): # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants From af6fb990865337cd88052f28b92cd35496cb66eb Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Fri, 1 Aug 2025 18:18:46 +0200 Subject: [PATCH 275/342] Thanks Matt for his tireless efforts! Co-authored-by: Rocketknight1 From 6f91a55afd14fb8521c4d55a776206b4a37a9a52 Mon Sep 17 00:00:00 2001 From: vb Date: Fri, 1 Aug 2025 18:21:36 +0200 Subject: [PATCH 276/342] Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py Co-authored-by: Matt --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index f45b7e4d9e76..5334169a2a52 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -422,7 +422,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): In addition to the normal inputs of `messages` and `tools`, this template also accepts the following kwargs: - "builtin_tools": A list, can contain "browser" and/or "python". - - "model_identity": A string that optionally describes the model identity". + - "model_identity": A string that optionally describes the model identity. - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". #} From afe89129805abb315ccb94dae0c7f7c78ac3322f Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 1 Aug 2025 17:19:09 +0000 Subject: [PATCH 277/342] fix --- src/transformers/integrations/mxfp4.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 46ac4d3f106a..87af16791597 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -363,6 +363,11 @@ def load_and_swizzle_mxfp4( blocks = module.gate_up_proj_blocks.view(local_experts, module.intermediate_size * 2, -1) else: blocks = module.down_proj_blocks.view(local_experts, -1, module.intermediate_size // 2) + + # TODO: we need to have the weights on cuda, refactor later + if target_device == "cpu": + target_device = "cuda" + with torch.cuda.device(target_device): triton_weight_tensor, weight_scale = swizzle_mxfp4(blocks.transpose(-2, -1), getattr(module, scales_attr).transpose(-2, -1)) From 7e540fc35e4bb6ef613446ad49bc85bfb8c8b850 Mon Sep 17 00:00:00 2001 From: Akos Hadnagy Date: Fri, 1 Aug 2025 18:21:36 +0000 Subject: [PATCH 278/342] Use ROCm kernels from HUB --- setup.py | 2 +- src/transformers/integrations/hub_kernels.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c8470306c3c5..6caa70773f7b 100644 --- a/setup.py +++ b/setup.py @@ -128,7 +128,7 @@ # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. - "kernels>=0.6.1,<0.7", + "kernels>=0.6.1,=<0.9", "librosa", "natten>=0.14.6,<0.15.0", "nltk<=3.8.1", diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index e824a5ab1f0e..3a19d3c6ed72 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -44,6 +44,11 @@ repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm", # revision="pure-layer-test", + ), + "rocm": LayerRepository( + repo_id="kernels-community/liger_kernels", + layer_name="LigerRMSNorm", + # revision="pure-layer-test", ) }, "MLP": { @@ -56,7 +61,11 @@ "cuda": LayerRepository( repo_id="kernels-community/megablocks", layer_name="MegaBlocksMoeMLP", - ) + ), + "rocm": LayerRepository( + repo_id="ahadnagy/megablocks", + layer_name="MegaBlocksMoeMLP", + ), }, } From 3e4ad36addc473162cb8b2ef201b54cd65e6d7f5 Mon Sep 17 00:00:00 2001 From: Akos Hadnagy Date: Fri, 1 Aug 2025 19:55:29 +0000 Subject: [PATCH 279/342] Make kernel modes explicit --- src/transformers/integrations/hub_kernels.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 3a19d3c6ed72..ea74402c38d4 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -18,6 +18,7 @@ from kernels import ( Device, LayerRepository, + Mode, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, @@ -45,11 +46,12 @@ layer_name="LigerRMSNorm", # revision="pure-layer-test", ), - "rocm": LayerRepository( + "rocm": { + Mode.INFERENCE: LayerRepository( repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm", # revision="pure-layer-test", - ) + )} }, "MLP": { "cuda": LayerRepository( @@ -62,10 +64,11 @@ repo_id="kernels-community/megablocks", layer_name="MegaBlocksMoeMLP", ), - "rocm": LayerRepository( + "rocm": { + Mode.INFERENCE: LayerRepository( repo_id="ahadnagy/megablocks", layer_name="MegaBlocksMoeMLP", - ), + )} }, } From e946804b09d981f2928aaf22807c24948fdbf135 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Fri, 1 Aug 2025 23:46:29 +0200 Subject: [PATCH 280/342] update final chat template by Matt. x2 --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 5334169a2a52..180d9be33177 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -686,12 +686,20 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} {#- This is a situation that should only occur in training, never in inference. #} {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} - {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {#- <|return|> indicates the end of generation, but <|end|> does not #} + {#- <|return|> should never be an input to the model, but we include it as the final token #} + {#- when training, so the model learns to emit it. #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }} {%- set last_tool_call.name = none %} {%- elif "thinking" in message %} {#- CoT is dropped during all previous turns, so we never render it for inference #} {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} {%- set last_tool_call.name = none %} + {%- elif loop.last and not add_generation_prompt %} + {#- <|return|> indicates the end of generation, but <|end|> does not #} + {#- <|return|> should never be an input to the model, but we include it as the final token #} + {#- when training, so the model learns to emit it. #} + {{- "<|start|>assistant<|message|>" + message.content + "<|return|>" }} {%- else %} {{- "<|start|>assistant<|message|>" + message.content + "<|end|>" }} {%- set last_tool_call.name = none %} From 1a8728d68295b210776715dd5cac1a1282cb6d62 Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Fri, 1 Aug 2025 23:46:34 +0200 Subject: [PATCH 281/342] Thanks Matt for his tireless efforts! Co-authored-by: Rocketknight1 From 50b825061d90a240be361f9df9a26b0b12b97910 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 1 Aug 2025 21:52:36 +0000 Subject: [PATCH 282/342] Fix installation --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6caa70773f7b..9b0d009f77a9 100644 --- a/setup.py +++ b/setup.py @@ -128,7 +128,7 @@ # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. - "kernels>=0.6.1,=<0.9", + "kernels>=0.6.1,<0.9", "librosa", "natten>=0.14.6,<0.15.0", "nltk<=3.8.1", From dec98d80d56c6762feac65e29062979f3fb24271 Mon Sep 17 00:00:00 2001 From: lewtun Date: Fri, 1 Aug 2025 17:09:04 -0500 Subject: [PATCH 283/342] Update setup.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ákos Hadnagy --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9b0d009f77a9..8e3513217ba0 100644 --- a/setup.py +++ b/setup.py @@ -128,7 +128,7 @@ # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support. "keras>2.9,<2.16", "keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras. - "kernels>=0.6.1,<0.9", + "kernels>=0.6.1,<=0.9", "librosa", "natten>=0.14.6,<0.15.0", "nltk<=3.8.1", From 0c6f911d215c290ac9db4b2a9cac59b43227f4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 1 Aug 2025 22:35:23 +0000 Subject: [PATCH 284/342] allow no content --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 180d9be33177..d8ccd98bc8fa 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -676,12 +676,16 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- if "tool_calls" in message %} {#- We assume max 1 tool call per message, and so we infer the tool call name #} {#- in "tool" messages from the most recent assistant tool call name #} - {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content }} + {%- set tool_call = message.tool_calls[0] %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" }} {{- "<|end|><|start|>assistant to=" }} - {{- "functions." + message.tool_calls[0].name + "<|channel|>commentary json<|message|>" }} - {{- message.tool_calls[0].arguments|tojson }} + {{- "functions." + tool_call.name + "<|channel|>commentary json<|message|>" }} + {{- tool_call.arguments|tojson }} {{- "<|end|>" }} - {%- set last_tool_call.name = message.tool_calls[0].name %} + {%- set last_tool_call.name = tool_call.name %} {%- elif "thinking" in message and loop.last and not add_generation_prompt %} {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} {#- This is a situation that should only occur in training, never in inference. #} From 181c625a2552616160a3a992070d4a4b42a77469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 1 Aug 2025 22:38:53 +0000 Subject: [PATCH 285/342] fix: update message handling in write_tokenizer function --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index d8ccd98bc8fa..e72acd7e39d9 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -680,8 +680,10 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- if tool_call.function %} {%- set tool_call = tool_call.function %} {%- endif %} - {{- "<|start|>assistant<|channel|>analysis<|message|>" }} - {{- "<|end|><|start|>assistant to=" }} + {%- if message.content %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }} + {%- endif %} + {{- "<|start|>assistant to=" }} {{- "functions." + tool_call.name + "<|channel|>commentary json<|message|>" }} {{- tool_call.arguments|tojson }} {{- "<|end|>" }} From 7c741230120a747e87d4d77d15d38a7372e89234 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 2 Aug 2025 02:56:26 +0000 Subject: [PATCH 286/342] Fix template logic for user message role --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 180d9be33177..b819fb28a97d 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -710,7 +710,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- endif %} {{- "<|start|>functions." + last_tool_call.name }} {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} - {%- else -%} + {%- elif message.role == 'user' -%} {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} {%- endif -%} {%- endfor -%} From 9d27880c2d072bf49b160b4514f296de4ae431d3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Aug 2025 03:59:50 +0000 Subject: [PATCH 287/342] last nits for CB and flash_paged! --- src/transformers/generation/continuous_batching.py | 4 ++-- src/transformers/integrations/flash_paged.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index a43b11fb40d3..26176a468b6f 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -from tokenizers import Tokenizer from tokenizers.decoders import DecodeStream from tqdm import tqdm @@ -33,6 +32,7 @@ from ..generation.configuration_utils import GenerationConfig from ..utils.logging import logging from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced +from ..tokenization_utils_fast import PreTrainedTokenizerFast class RequestStatus(Enum): @@ -751,7 +751,7 @@ def __init__( self.setup_static_tensors() - self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path) + self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path) self.decode_stream = DecodeStream(skip_special_tokens=True) @traced(standalone=True) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index a7bf5ae57717..3e66a340794c 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -52,6 +52,7 @@ def paged_attention_forward( if implementation is not None: flash_attn_varlen_func = implementation.flash_attn_varlen_func + custom_kwargs = {"s_aux", kwargs.get("s_aux")} attn_output = flash_attn_varlen_func( q.transpose(1, 2).squeeze(0).contiguous(), k.transpose(1, 2).squeeze(0).contiguous(), @@ -64,7 +65,7 @@ def paged_attention_forward( causal=True, # kind of a must, it automatically aligns the mask for q < k window_size=(-1, -1), # -1 means infinite context window # block_table=block_tables, -> torch.Tensor - # **kwargs, + **custom_kwargs, ) if isinstance(attn_output, tuple): attn_output = attn_output[0] From 4cf6186b74e0d211c9191bcf4ab943b9df0e11d7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Aug 2025 04:08:35 +0000 Subject: [PATCH 288/342] there was one bad merge --- .../integrations/tensor_parallel.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index df1842cf069c..3131af0aca40 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -584,6 +584,43 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me return outputs.to_local() if use_local_output else outputs +class ReduceFromModelParallelRegion(torch.autograd.Function): + """ + All-reduce in forward pass, identity in backward pass. + This is the `g` function in the paper: https://arxiv.org/abs/1909.08053 + """ + + @staticmethod + def forward(ctx, x, device_mesh): + if device_mesh.size() == 1: + return x + dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) + return x + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class CopyToModelParallelRegion(torch.autograd.Function): + """ + Copy in forward pass, all-reduce in backward pass. + This is the `f` function in the paper: https://arxiv.org/abs/1909.08053 + """ + + @staticmethod + def forward(ctx, x, device_mesh): + ctx.device_mesh = device_mesh + return x + + @staticmethod + def backward(ctx, grad_output): + if ctx.device_mesh.size() == 1: + return grad_output + dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group()) + return grad_output + + class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. From cac4c098a77398e4e975f031c9b897e046fe34b2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Aug 2025 04:33:51 +0000 Subject: [PATCH 289/342] fix CB (hardcode for now, its just using kv groups instead) --- src/transformers/generation/continuous_batching.py | 4 ++-- src/transformers/integrations/flash_paged.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 26176a468b6f..adb15d02d5f2 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -182,7 +182,7 @@ def __init__( f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." ) # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - self.num_key_value_heads //= tp_size + # self.num_key_value_heads //= tp_size self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads @@ -1250,7 +1250,7 @@ def _run_generation_loop(self): self.model.device, self.model.dtype, num_requests=len(self.input_queue.queue), - tp_size=getattr(self.model, "tp_size"), + tp_size=getattr(self.model, "_tp_size", 8), # TODO quantized converted don't set this ) scheduler = None diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 3e66a340794c..4d4b73d85c52 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -52,7 +52,7 @@ def paged_attention_forward( if implementation is not None: flash_attn_varlen_func = implementation.flash_attn_varlen_func - custom_kwargs = {"s_aux", kwargs.get("s_aux")} + custom_kwargs = {"s_aux": kwargs.get("s_aux")} attn_output = flash_attn_varlen_func( q.transpose(1, 2).squeeze(0).contiguous(), k.transpose(1, 2).squeeze(0).contiguous(), From eeef8c8dbeaf939699926934167b0538a4a88cb5 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Sat, 2 Aug 2025 07:56:27 +0000 Subject: [PATCH 290/342] fix --- src/transformers/integrations/mxfp4.py | 29 +++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 87af16791597..70d27e065111 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -50,10 +50,12 @@ # Copied from GPT_OSS repo and vllm def quantize_to_mxfp4(w): from triton_kernels.numerics_details.mxfp import downcast_to_mxfp + w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) w, w_scale = swizzle_mxfp4(w, w_scale) return w, w_scale + def swizzle_mxfp4(w, w_scale): from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout @@ -76,6 +78,7 @@ def swizzle_mxfp4(w, w_scale): w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale + # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public def convert_moe_packed_tensors( @@ -329,6 +332,7 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** delattr(module, blocks_attr) delattr(module, scales_attr) + def load_and_swizzle_mxfp4( module, param_name, param_value, target_device, **kwargs ): @@ -363,29 +367,44 @@ def load_and_swizzle_mxfp4( blocks = module.gate_up_proj_blocks.view(local_experts, module.intermediate_size * 2, -1) else: blocks = module.down_proj_blocks.view(local_experts, -1, module.intermediate_size // 2) - # TODO: we need to have the weights on cuda, refactor later if target_device == "cpu": target_device = "cuda" with torch.cuda.device(target_device): - triton_weight_tensor, weight_scale = swizzle_mxfp4(blocks.transpose(-2, -1), getattr(module, scales_attr).transpose(-2, -1)) + triton_weight_tensor, weight_scale = swizzle_mxfp4( + blocks.transpose(-2, -1), getattr(module, scales_attr).transpose(-2, -1) + ) # need to overwrite the shapes for the kernels if proj == "gate_up_proj": - triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2]) + triton_weight_tensor.shape = torch.Size( + [local_experts, module.hidden_size, module.intermediate_size * 2] + ) else: - triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size]) + triton_weight_tensor.shape = torch.Size( + [local_experts, module.intermediate_size, module.hidden_size] + ) # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor + triton_weight_tensor.storage.data = triton_weight_tensor.storage.data.to(target_device) + weight_scale.storage.data = weight_scale.storage.data.to(target_device) + setattr(module, proj, triton_weight_tensor) - setattr(module, f"{proj}_precision_config", PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) + + setattr( + module, + f"{proj}_precision_config", + PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())), + ) # delete blocks and scales delattr(module, scales_attr) delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) del blocks + + def _replace_with_mxfp4_linear( model, modules_to_not_convert=None, From 45fbc18590867af65150d08ed19dfe01199aa3fa Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 2 Aug 2025 09:05:25 +0000 Subject: [PATCH 291/342] better fix for device_map --- src/transformers/integrations/mxfp4.py | 15 +++++++-------- src/transformers/quantizers/quantizer_mxfp4.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 70d27e065111..ebac23890683 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -362,18 +362,21 @@ def load_and_swizzle_mxfp4( # Check if both blocks and scales both not on on meta device if blocks.device.type != "meta" and scales.device.type != "meta": # need it for ep - local_experts = getattr(module, blocks_attr).size(0) + local_experts = blocks.size(0) if proj == "gate_up_proj": - blocks = module.gate_up_proj_blocks.view(local_experts, module.intermediate_size * 2, -1) + blocks = blocks.view(local_experts, module.intermediate_size * 2, -1) else: - blocks = module.down_proj_blocks.view(local_experts, -1, module.intermediate_size // 2) + blocks = blocks.view(local_experts, -1, module.intermediate_size // 2) # TODO: we need to have the weights on cuda, refactor later if target_device == "cpu": target_device = "cuda" + # TODO: not use why torch.cuda.device + blocks = blocks.to(target_device) + scales = scales.to(target_device) with torch.cuda.device(target_device): triton_weight_tensor, weight_scale = swizzle_mxfp4( - blocks.transpose(-2, -1), getattr(module, scales_attr).transpose(-2, -1) + blocks.transpose(-2, -1), scales.transpose(-2, -1) ) # need to overwrite the shapes for the kernels @@ -387,11 +390,7 @@ def load_and_swizzle_mxfp4( ) # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor - triton_weight_tensor.storage.data = triton_weight_tensor.storage.data.to(target_device) - weight_scale.storage.data = weight_scale.storage.data.to(target_device) - setattr(module, proj, triton_weight_tensor) - setattr( module, f"{proj}_precision_config", diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 679199a124f7..949f66cc605d 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -227,7 +227,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs if self.quantization_config.dequantize: self.remove_quantization_config(model) # clean cache due to triton ops - if not torch.cuda.is_available(): + if torch.cuda.is_available(): torch.cuda.empty_cache() def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]): From 6dd3a723bd4f0b36a5f41d212e7f2a017817ff0e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 2 Aug 2025 22:30:04 +0000 Subject: [PATCH 292/342] minor device fix --- src/transformers/integrations/mxfp4.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index ebac23890683..56ffee437b12 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -368,10 +368,9 @@ def load_and_swizzle_mxfp4( else: blocks = blocks.view(local_experts, -1, module.intermediate_size // 2) # TODO: we need to have the weights on cuda, refactor later - if target_device == "cpu": + if getattr(target_device, "type", target_device) == "cpu": target_device = "cuda" - - # TODO: not use why torch.cuda.device + # TODO: check why we still do move the tensors despite the context manager blocks = blocks.to(target_device) scales = scales.to(target_device) with torch.cuda.device(target_device): From 5ef7f3f416758dde1d700f892b554915ec32d137 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 07:42:41 +0000 Subject: [PATCH 293/342] Fix flash paged --- src/transformers/modeling_flash_attention_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 47744eaca3f2..bfab34703971 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -389,7 +389,8 @@ def _flash_attention_forward( flash_kwargs["deterministic"] = det if softcap is not None: flash_kwargs["softcap"] = softcap - + if "s_aux" in kwargs: + flash_kwargs["s_aux"] = kwargs.get("s_aux") query_states, key_states, value_states = fa_peft_integration_check( query_states, key_states, value_states, target_dtype ) From d2303c71106b05b201f60db861bc98e76b7e3b2d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 07:45:15 +0000 Subject: [PATCH 294/342] updates --- src/transformers/integrations/tensor_parallel.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 3131af0aca40..b393336e431f 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -442,9 +442,9 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): if isinstance(outputs, torch.Tensor): - dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) + dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) else: - dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) + dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) return outputs def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: @@ -594,7 +594,10 @@ class ReduceFromModelParallelRegion(torch.autograd.Function): def forward(ctx, x, device_mesh): if device_mesh.size() == 1: return x - dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) + if isinstance(x, tuple): + dist.all_reduce(x[0], op=dist.ReduceOp.SUM, group=device_mesh.get_group()) + else: + dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) return x @staticmethod @@ -1086,9 +1089,9 @@ def shard_and_distribute_module( if dist.get_rank() == 0: if current_shard_plan is None: - logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") + logger.debug(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") else: - logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") + logger.debug(f"Tensor sharding plan for {param_name}: {current_shard_plan}") if current_shard_plan is not None: try: @@ -1167,6 +1170,7 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): plan = getattr(module.config, "base_model_tp_plan", {}) model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) + logger.debug(f"Final TP plan for the model: {model._tp_plan}") if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available: for v in model._tp_plan.values(): if v not in ALL_PARALLEL_STYLES: @@ -1176,6 +1180,8 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False) + if plan is not None: + logger.debug(f"adding hooks for {name},{plan}") add_tensor_parallel_hooks_to_module( model=model, module=module, From ed511f215810d390c79a9c58b1d56eca246e1688 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 08:04:19 +0000 Subject: [PATCH 295/342] Revert "remove dtensors, not explicit (#39840)" This reverts commit 6dfd561d9cd722dfc09f702355518c6d09b9b4e3. --- src/transformers/modeling_utils.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index aeffe0055b14..ee8d0faa6690 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4122,16 +4122,9 @@ def save_pretrained( for shard_file, tensors in filename_to_tensors: shard = {} for tensor in tensors: - if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None: - plan = _get_parameter_tp_plan(tensor, self._tp_plan) - full_tensor = state_dict[tensor] - if isinstance(state_dict[tensor], DTensor): - full_tensor = full_tensor.full_tensor() - elif plan is not None: - shard_dim = -1 if "rowwise" in plan else 0 - gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())] - torch.distributed.all_gather(gather_list, full_tensor) - full_tensor = torch.cat(gather_list, dim=shard_dim) + if _is_dtensor_available and isinstance(state_dict[tensor], DTensor): + full_tensor = state_dict[tensor].full_tensor() + # to get the correctly ordered tensor we need to repack if packed if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",): full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2) shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly From e9b3708ec68beac5a8c4cf9e77e5a6e46ff4460c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 15:45:11 +0000 Subject: [PATCH 296/342] update --- src/transformers/integrations/tensor_parallel.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b393336e431f..73857a65ee89 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -595,14 +595,14 @@ def forward(ctx, x, device_mesh): if device_mesh.size() == 1: return x if isinstance(x, tuple): - dist.all_reduce(x[0], op=dist.ReduceOp.SUM, group=device_mesh.get_group()) + dist.all_reduce(x[0], op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) else: - dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) + dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) return x @staticmethod def backward(ctx, grad_output): - return grad_output + return grad_output, None class CopyToModelParallelRegion(torch.autograd.Function): @@ -621,7 +621,7 @@ def backward(ctx, grad_output): if ctx.device_mesh.size() == 1: return grad_output dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group()) - return grad_output + return grad_output, None class ColwiseParallel(TensorParallelLayer): @@ -721,12 +721,14 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ if hasattr(mod, "bias") and mod.bias is not None: mod._bias = mod.bias mod.bias = None - - input_tensor = inputs[0] - return input_tensor + return inputs @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + """ + We add the bias after outputs have been collected because the bias is replicated across shards. + If you add it before you are adding it as many time as the world size. + """ outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh) if hasattr(mod, "_bias"): outputs += mod._bias From 70750d9ac145efc325982cf3dad846a27860b1cb Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 15:46:40 +0000 Subject: [PATCH 297/342] Revert "remove dtensors, not explicit (#39840)" This reverts commit 6dfd561d9cd722dfc09f702355518c6d09b9b4e3. --- .../integrations/tensor_parallel.py | 194 ++++++------------ tests/tensor_parallel/test_tensor_parallel.py | 124 ++++++----- 2 files changed, 128 insertions(+), 190 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 73857a65ee89..e61fc2135070 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -18,7 +18,6 @@ import os import re from functools import partial, reduce -from typing import Optional import torch import torch.distributed as dist @@ -151,7 +150,6 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig "F64": torch.float64, "I64": torch.int64, "F8_E4M3": torch.float8_e4m3fn, - "F8_E5M2": torch.float8_e5m2, } @@ -442,9 +440,9 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): if isinstance(outputs, torch.Tensor): - dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) + dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) else: - dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) + dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) return outputs def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: @@ -527,149 +525,81 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, return param -class VocabParallel(TensorParallelLayer): +class ColwiseParallel(TensorParallelLayer): """ - VocabParallel is a tensor parallel layer that shards the embedding table along the last dimension. - No need to do input masking as embedding would be stored in `_MaskPartial` which handles it (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_embedding_ops.py#L70) - - This is useful if you want to train with long sequence length! + General tensor parallel layer for transformers. """ def __init__( self, *, - input_layouts: Optional[Placement] = None, - output_layouts: Optional[Placement] = None, - weight_dim_sharding: int = 0, + input_layouts: Placement | None = None, + output_layouts: Placement | None = None, use_local_output: bool = True, use_dtensor=True, ): super().__init__() self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) self.desired_input_layouts = (Replicate(),) - self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output self.use_dtensor = use_dtensor - self.weight_dim_sharding = weight_dim_sharding @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + # TODO: figure out dynamo support for instance method and switch this to instance method + # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) return input_tensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - # Shard the embedding table along dim 0 (vocab dimension) + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) if param_type == "bias": parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) - placements = [Shard(-1)] + shard = [Shard(-1)] else: - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, self.weight_dim_sharding) - placements = [Shard(self.weight_dim_sharding)] + shard = [Shard(-2)] + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) + parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, placements, run_check=False) - return nn.Parameter(parameter) + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) + return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=False) - return outputs.to_local() if use_local_output else outputs - - -class ReduceFromModelParallelRegion(torch.autograd.Function): - """ - All-reduce in forward pass, identity in backward pass. - This is the `g` function in the paper: https://arxiv.org/abs/1909.08053 - """ - - @staticmethod - def forward(ctx, x, device_mesh): - if device_mesh.size() == 1: - return x - if isinstance(x, tuple): - dist.all_reduce(x[0], op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) - else: - dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=False, group=device_mesh.get_group()) - return x - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class CopyToModelParallelRegion(torch.autograd.Function): - """ - Copy in forward pass, all-reduce in backward pass. - This is the `f` function in the paper: https://arxiv.org/abs/1909.08053 - """ - - @staticmethod - def forward(ctx, x, device_mesh): - ctx.device_mesh = device_mesh - return x - - @staticmethod - def backward(ctx, grad_output): - if ctx.device_mesh.size() == 1: - return grad_output - dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group()) - return grad_output, None - - -class ColwiseParallel(TensorParallelLayer): - """ - General tensor parallel layer for transformers. - """ + # back to local tensor + return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs - def __init__( - self, - *, - input_layouts: Placement | None = None, - output_layouts: Placement | None = None, - use_local_output: bool = True, - use_dtensor=True, - ): - super().__init__() - self.input_layouts = (input_layouts or Replicate(),) - self.output_layouts = (output_layouts or Shard(-1),) - self.desired_input_layouts = (Replicate(),) - self.use_local_output = use_local_output - self.use_dtensor = use_dtensor - - @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - # annotate module input placements/sharding with input_layouts - input_tensor = inputs[0] - return input_tensor +class PackedColwiseParallel(ColwiseParallel): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) - if param_type == "bias": - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) - else: - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) - + parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2) parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() - + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) - @staticmethod - def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - outputs = CopyToModelParallelRegion.apply(outputs, device_mesh) - return outputs - class RowwiseParallel(TensorParallelLayer): """ @@ -705,15 +635,23 @@ def __init__( self.use_dtensor = use_dtensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - if param_type == "bias": - parameter = param[:] - else: + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + if param_type != "bias": parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + shard = [Shard(-1)] + else: + shard = [Replicate()] + parameter = param[:] parameter = parameter.to(param_casting_dtype) if to_contiguous: parameter = parameter.contiguous() - + if self.use_dtensor: + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod @@ -721,18 +659,26 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ if hasattr(mod, "bias") and mod.bias is not None: mod._bias = mod.bias mod.bias = None - return inputs + + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + return input_tensor @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - """ - We add the bias after outputs have been collected because the bias is replicated across shards. - If you add it before you are adding it as many time as the world size. - """ - outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh) + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) if hasattr(mod, "_bias"): outputs += mod._bias - return outputs + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: module._distribute_module_applied = True @@ -757,21 +703,6 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: ) -class PackedColwiseParallel(ColwiseParallel): - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - # NOTE(3outeille): need to be deprecated as no longer using dtensors - # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) - # means Colwise as Linear is input * weight^T + bias, where - # weight would become Shard(1) - parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2) - parameter = parameter.to(param_casting_dtype) - if to_contiguous: - parameter = parameter.contiguous() - if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False) - return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) - - class PackedRowwiseParallel(RowwiseParallel): def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) @@ -975,8 +906,6 @@ class ParallelInterface(GeneralInterface): # a new instance is created (in order to locally override a given entry) _global_mapping = ( { - "vocab_parallel_rowwise": VocabParallel(weight_dim_sharding=0), - "vocab_parallel_colwise": VocabParallel(weight_dim_sharding=-2, output_layouts=Replicate()), "colwise": ColwiseParallel(), "rowwise": RowwiseParallel(), "colwise_rep": ColwiseParallel(output_layouts=Replicate()), @@ -1067,7 +996,7 @@ def add_tensor_parallel_hooks_to_module( def shard_and_distribute_module( - model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param=True + model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh ): # TODO: rename to shard_and_distribute_param r""" This function is called in `from_pretrained` when loading a model's checkpoints. @@ -1091,9 +1020,9 @@ def shard_and_distribute_module( if dist.get_rank() == 0: if current_shard_plan is None: - logger.debug(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") + logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") else: - logger.debug(f"Tensor sharding plan for {param_name}: {current_shard_plan}") + logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") if current_shard_plan is not None: try: @@ -1112,8 +1041,7 @@ def shard_and_distribute_module( # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) - if set_param: - setattr(module_to_tp, param_type, param) + setattr(module_to_tp, param_type, param) # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) return param @@ -1163,6 +1091,7 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): _plan = "_ep_plan" model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy() + # now fetch my childrens for name, module in model.named_children(): if plan := getattr(module, _plan, getattr(module, "tp_plan", None)): model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) @@ -1172,7 +1101,6 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): plan = getattr(module.config, "base_model_tp_plan", {}) model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) - logger.debug(f"Final TP plan for the model: {model._tp_plan}") if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available: for v in model._tp_plan.values(): if v not in ALL_PARALLEL_STYLES: @@ -1182,8 +1110,6 @@ def distribute_model(model, distributed_config, device_mesh, tp_size): from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False) - if plan is not None: - logger.debug(f"adding hooks for {name},{plan}") add_tensor_parallel_hooks_to_module( model=model, module=module, diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index 985a3c6a23b0..1904fc8bd1e7 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -64,68 +64,76 @@ def size(self): assert torch.allclose(unpacked_weights, original_packed_weights) -class TensorParallelTestBase(TestCasePlus): +class TestTensorParallel(TestCasePlus): nproc_per_node = 2 - def run_torch_distributed_test(self, script: str, is_torchrun: bool = True): - """Run the given Python script in a subprocess using torchrun or python3.""" + def torchrun(self, script: str, is_torchrun: bool = True): + """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: tmp.write(script) tmp.flush() tmp.seek(0) - cmd = ( - ( - f"torchrun --nproc_per_node {self.nproc_per_node} " - f"--master_port {get_torch_dist_unique_port()} {tmp.name}" + if is_torchrun: + cmd = ( + f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" ).split() - if is_torchrun - else ["python3", tmp.name] - ) + else: + cmd = ["python3", tmp.name] + # Note that the subprocess will be waited for here, and raise an error if not successful try: - subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True) + _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True) except subprocess.CalledProcessError as e: - raise Exception(f"Subprocess failed with:\nSTDOUT:\n{e.stdout}\n\nSTDERR:\n{e.stderr}") - - def run_tensor_parallel_test(self, model_id: str, mode: str = "training", expected_output: str = None): - """ - Runs a tensor-parallel test for either training (forward) or generate mode. - - Args: - model_id: The model to test. - mode: "training" or "generate". - expected_output: Token or string to assert for training mode. - """ - if mode not in ("training", "generate"): - raise ValueError(f"Invalid mode '{mode}', must be 'training' or 'generate'") - - # Only the outputs line changes between training and generate - outputs_line = ( - "outputs = model(inputs)" - if mode == "training" - else 'outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")' - ) + raise Exception(f"The following error was captured: {e.stderr}") + + def test_model_forward(self): + script_to_run = textwrap.dedent( + """ + import torch + import os + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_id = "JackFram/llama-68m" + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") + torch.distributed.barrier() + + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break + + assert has_dtensor == 1, "TP model must has DTensor" + + tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) + prompt = "Can I help" + + inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) + outputs = model(inputs) - # Expected assertion differs slightly depending on the mode - if mode == "training": - assertion = f""" next_token_logits = outputs[0][:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1) response = tokenizer.decode(next_token) - assert response == "{expected_output}", f"Expected token '{{expected_output}}', got '{{response}}'" - """ - else: - assertion = """ - output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) - assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'" + assert response == "with" + + torch.distributed.barrier() + torch.distributed.destroy_process_group() """ + ) + self.torchrun(script_to_run) - script = f""" + def test_model_generate(self): + script_to_run = textwrap.dedent( + """ import torch import os from transformers import AutoModelForCausalLM, AutoTokenizer - model_id = "{model_id}" + model_id = "JackFram/llama-68m" rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) @@ -133,18 +141,30 @@ def run_tensor_parallel_test(self, model_id: str, mode: str = "training", expect model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") torch.distributed.barrier() - tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) + model.forward = torch.compile(model.forward) + + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break + + assert has_dtensor == 1, "TP model must has DTensor" + + tokenizer = AutoTokenizer.from_pretrained(model_id) prompt = "Can I help" inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) - {outputs_line} + outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static") - {assertion} + output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'" torch.distributed.barrier() torch.distributed.destroy_process_group() - """ - self.run_torch_distributed_test(textwrap.dedent(script)) + """ + ) + self.torchrun(script_to_run) @require_huggingface_hub_greater_or_equal("0.31.4") def test_model_save(self): @@ -171,7 +191,7 @@ def test_model_save(self): model.save_pretrained(result_dir) """ ) - self.run_torch_distributed_test(script_to_run, is_torchrun=is_torchrun) + self.torchrun(script_to_run, is_torchrun=is_torchrun) non_tp_model_path = os.path.join(tmp_dir, "nontp") tp_model_path = os.path.join(tmp_dir, "tp") @@ -189,14 +209,6 @@ def test_model_save(self): del non_tp_tensor, tp_tensor -class TestTensorParallel(TensorParallelTestBase): - def test_model_training(self): - self.run_tensor_parallel_test(model_id="JackFram/llama-68m", mode="training", expected_output="with") - - def test_model_generate(self): - self.run_tensor_parallel_test(model_id="JackFram/llama-68m", mode="generate") - - @require_torch_multi_accelerator class TestTensorParallelAccelerator(TestTensorParallel): nproc_per_node = backend_device_count(torch_device) From 35576899d893b86217651c5a53e7b26be18e6646 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sun, 3 Aug 2025 16:04:21 +0000 Subject: [PATCH 298/342] fix merge --- src/transformers/integrations/tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index e61fc2135070..b988fee74a0b 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -996,7 +996,7 @@ def add_tensor_parallel_hooks_to_module( def shard_and_distribute_module( - model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh + model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh, set_param=True ): # TODO: rename to shard_and_distribute_param r""" This function is called in `from_pretrained` when loading a model's checkpoints. From b939303b06438efb7eb767905db90767c795ac98 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Sun, 3 Aug 2025 17:44:48 +0000 Subject: [PATCH 299/342] fix --- src/transformers/quantizers/quantizer_mxfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 949f66cc605d..1358f3394da0 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -149,7 +149,7 @@ def create_quantized_param( unexpected_keys: Optional[list[str]] = None, **kwargs, ): - if is_triton_kernels_availalble(): + if is_triton_kernels_availalble() and is_triton_available("3.4.0"): from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4 From d238182f465d70bfb0d357041904e903bdcb7208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 3 Aug 2025 19:23:23 +0000 Subject: [PATCH 300/342] Fix line break when custom model indentity --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index f92acc8351f1..9205f7364303 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -610,10 +610,9 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {#- System Message Construction ============================================ #} {%- macro build_system_message() -%} {%- if model_identity is not defined %} - {{- "You are ChatGPT, a large language model trained by OpenAI.\n" -}} - {%- else %} - {{- model_identity }} + {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %} {%- endif %} + {{- model_identity + "\n" }} {{- "Knowledge cutoff: 2024-06\n" }} {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }} {%- if reasoning_effort is not defined %} From 088a6070735b2a2a76c30ac7f2cf39515645c2e2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 06:39:02 +0000 Subject: [PATCH 301/342] nits testing --- .../generation/continuous_batching.py | 24 ++-- tests/models/gpt_oss/test_modeling_gpt_oss.py | 104 +++++++++++++----- 2 files changed, 87 insertions(+), 41 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index adb15d02d5f2..5089551fdf28 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -190,19 +190,21 @@ def __init__( self.num_hidden_layers = config.num_hidden_layers # Calculate optimal block size and number if not provided - num_blocks = getattr(generation_config, "num_blocks", None) + num_blocks = getattr(generation_config, "num_blocks", 1024) block_size = getattr(generation_config, "block_size", 32) max_memory_percent = getattr(generation_config, "max_memory", 0.9) - num_blocks, max_batch_tokens = compute_optimal_blocks( - generation_config.max_new_tokens, - block_size=block_size, - head_dim=self.head_dim, - num_layers=self.num_hidden_layers, - num_heads=self.num_key_value_heads, - max_memory_percent=max_memory_percent, - dtype=dtype, - num_blocks=num_blocks, - ) + max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256) + if num_blocks is None or max_batch_tokens is None: + num_blocks, max_batch_tokens = compute_optimal_blocks( + generation_config.max_new_tokens, + block_size=block_size, + head_dim=self.head_dim, + num_layers=self.num_hidden_layers, + num_heads=self.num_key_value_heads, + max_memory_percent=max_memory_percent, + dtype=dtype, + num_blocks=num_blocks, + ) logger.warning( f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}" ) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index df35fc6114b0..c6df33b15046 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -19,7 +19,12 @@ from parameterized import parameterized from tests.tensor_parallel.test_tensor_parallel import TensorParallelTestBase -from transformers import AutoModelForCausalLM, AutoTokenizer, GptOssConfig, is_torch_available +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + GptOssConfig, + is_torch_available, +) from transformers.testing_utils import ( cleanup, require_read_token, @@ -79,7 +84,9 @@ class GptOssModelTest(CausalLMModelTest, unittest.TestCase): def setUp(self): self.model_tester = GptOssModelTester(self) - self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37) + self.config_tester = ConfigTester( + self, config_class=GptOssConfig, hidden_size=37 + ) @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): @@ -95,16 +102,22 @@ def test_eager_matches_sdpa_generate(self): @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate - @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") + @unittest.skip( + "GptOss has HybridCache which is not compatible with assisted decoding" + ) def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass - @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") + @unittest.skip( + "GptOss has HybridCache which is not compatible with assisted decoding" + ) def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass @pytest.mark.generate - @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") + @unittest.skip( + "GptOss has HybridCache which is not compatible with assisted decoding" + ) def test_assisted_decoding_sample(self): pass @@ -128,15 +141,21 @@ def test_contrastive_generate_dict_outputs_use_cache(self): def test_contrastive_generate_low_memory(self): pass - @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip( + "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) def test_generate_with_static_cache(self): pass - @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip( + "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + @unittest.skip( + "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) def test_generate_continue_from_inputs_embeds(self): pass @@ -147,7 +166,9 @@ def test_generate_continue_from_inputs_embeds(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.") + @unittest.skip( + "GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together." + ) def test_eager_matches_fa2_generate(self): pass @@ -159,7 +180,10 @@ def test_flash_attn_2_equivalence(self): @slow @require_torch_accelerator class GptOssIntegrationTest(unittest.TestCase): - input_text = ["Hello I am doing", "Hi today"] + input_text = [ + "Roses are red, violets", + "How are you? Tell me the name of the president of", + ] def setUp(self): cleanup(torch_device, gc_collect=True) @@ -168,18 +192,22 @@ def tearDown(self): cleanup(torch_device, gc_collect=True) @staticmethod - def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs): + def load_and_forward( + model_id, attn_implementation, input_text, **pretrained_kwargs + ): if not isinstance(attn_implementation, list): attn_implementation = [attn_implementation] text = [] - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs).to( - torch_device - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs + ).to(torch_device) for attn in attn_implementation: model.set_attn_implementation(attn) tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to( + torch_device + ) output = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_text = tokenizer.batch_decode(output, skip_special_tokens=False) @@ -188,53 +216,61 @@ def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwa @require_read_token def test_model_20b_bf16(self): - model_id = "" + model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" EXPECTED_TEXTS = [ "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", "Hi today I'm going to be talking about the history of the United States. The United States of America", ] output_text = self.load_and_forward( model_id, - ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + [ + "eager", + "ft-hf-o-c/vllm-flash-attn3", + ], self.input_text, ) self.assertEqual(output_text[0], EXPECTED_TEXTS) self.assertEqual(output_text[1], EXPECTED_TEXTS) - self.assertEqual(output_text[2], EXPECTED_TEXTS) @require_read_token def test_model_20b_bf16_use_kernels(self): - model_id = "" + model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" EXPECTED_TEXTS = [ "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", "Hi today I'm going to be talking about the history of the United States. The United States of America", ] output_text = self.load_and_forward( model_id, - ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + [ + "eager", + "ft-hf-o-c/vllm-flash-attn3", + ], self.input_text, use_kenels=True, ) self.assertEqual(output_text[0], EXPECTED_TEXTS) self.assertEqual(output_text[1], EXPECTED_TEXTS) - self.assertEqual(output_text[2], EXPECTED_TEXTS) @require_read_token def test_model_120b_bf16_use_kernels(self): - model_id = "" + model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs" EXPECTED_TEXTS = [ - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", - "Hi today I'm going to be talking about the history of the United States. The United States of America", + ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", + """How are you? Tell me the +name of the president of the United States." The assistant should respond with the name of the president. The user is aski +ng for""", ] output_text = self.load_and_forward( model_id, - ["eager", "kernel-community/triton-flash-attn-sink", "ft-hf-o-c/vllm-flash-attn3"], + [ + "eager", + "ft-hf-o-c/vllm-flash-attn3", + ], self.input_text, use_kenels=True, ) self.assertEqual(output_text[0], EXPECTED_TEXTS) self.assertEqual(output_text[1], EXPECTED_TEXTS) - self.assertEqual(output_text[2], EXPECTED_TEXTS) @slow @@ -242,16 +278,24 @@ def test_model_120b_bf16_use_kernels(self): class GptOssTPTest(TensorParallelTestBase): def test_model_training(self): self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/20b-converted-quantized", mode="training", expected_output="you with something?" + model_id="/fsx/vb/new-oai/gpt-oss-20b-trfs", + mode="training", + expected_output="you with something?", ) self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/120b-converted-quantized", mode="training", expected_output="you with something?" + model_id="/fsx/vb/new-oai/gpt-oss-120b-trfs", + mode="training", + expected_output="you with something?", ) def test_model_generate(self): self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/20b-converted-quantized", mode="generate", expected_output="with something" + model_id="/fsx/vb/new-oai/gpt-oss-20b-trfs", + mode="generate", + expected_output="with something", ) self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/120b-converted-quantized", mode="generate", expected_output="with something" + model_id="/fsx/vb/new-oai/20b-converted-quantized", + mode="generate", + expected_output="with something", ) From d91814b5296b257e715bda9d08dd916c7ac253da Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 06:43:11 +0000 Subject: [PATCH 302/342] to locals first and pass sliding window to flash paged --- src/transformers/integrations/flash_paged.py | 11 +++++++++-- src/transformers/integrations/tensor_parallel.py | 5 +++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 4d4b73d85c52..31e329860cfa 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -48,8 +48,15 @@ def paged_attention_forward( window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. """ - k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + k, v = cache.update( + k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs + ) + sliding_window = ( + (-1, -1) + if not getattr(module, "sliding_window", False) + else (module.sliding_window, 0) + ) if implementation is not None: flash_attn_varlen_func = implementation.flash_attn_varlen_func custom_kwargs = {"s_aux": kwargs.get("s_aux")} @@ -63,7 +70,7 @@ def paged_attention_forward( max_seqlen_k, softmax_scale=module.scaling, causal=True, # kind of a must, it automatically aligns the mask for q < k - window_size=(-1, -1), # -1 means infinite context window + window_size=sliding_window, # -1 means infinite context window # block_table=block_tables, -> torch.Tensor **custom_kwargs, ) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index b988fee74a0b..b5c19412a41e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -657,7 +657,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): if hasattr(mod, "bias") and mod.bias is not None: - mod._bias = mod.bias + mod._bias = mod.bias.to_local() mod.bias = None input_tensor = inputs[0] @@ -675,10 +675,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # 2. to shard -> reduce_scatter if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) + outputs = outputs.to_local() if hasattr(mod, "_bias"): outputs += mod._bias # back to local tensor if use_local_output is True - return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs + return outputs def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: module._distribute_module_applied = True From 27bd828dbc0c6b86ae66fad7e84c4d912ded0de0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 08:26:19 +0000 Subject: [PATCH 303/342] register modes for MegaBlocksMoeMlp --- src/transformers/integrations/hub_kernels.py | 30 +++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index ea74402c38d4..ad5e08d8da4d 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -48,10 +48,11 @@ ), "rocm": { Mode.INFERENCE: LayerRepository( - repo_id="kernels-community/liger_kernels", - layer_name="LigerRMSNorm", - # revision="pure-layer-test", - )} + repo_id="kernels-community/liger_kernels", + layer_name="LigerRMSNorm", + # revision="pure-layer-test", + ) + }, }, "MLP": { "cuda": LayerRepository( @@ -60,15 +61,22 @@ ) }, "MegaBlocksMoeMLP": { - "cuda": LayerRepository( - repo_id="kernels-community/megablocks", - layer_name="MegaBlocksMoeMLP", - ), + "cuda": { + Mode.TRAINING: LayerRepository( + repo_id="kernels-community/megablocks", + layer_name="MegaBlocksMoeMLP", + ), + Mode.INFERENCE: LayerRepository( + repo_id="kernels-community/megablocks", + layer_name="MegaBlocksMoeMLP", + ), + }, "rocm": { Mode.INFERENCE: LayerRepository( - repo_id="ahadnagy/megablocks", - layer_name="MegaBlocksMoeMLP", - )} + repo_id="ahadnagy/megablocks", + layer_name="MegaBlocksMoeMLP", + ) + }, }, } From b667b7c1b9b9d60c50a2be33d1cdba8016006e25 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 08:47:54 +0000 Subject: [PATCH 304/342] add integration test in fixtures -> now update the tests to use it! --- .../generation/continuous_batching.py | 9 +- tests/fixtures/gpt_oss/integration_tests.json | 346 ++++++++++++++++++ 2 files changed, 354 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/gpt_oss/integration_tests.json diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 5089551fdf28..7ab0554d296a 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -962,7 +962,14 @@ def _build_tensors( @traced def _sync(self): - return self.output_ids.tolist()[0] # should be the only synch we do + if self.output_ids is not None: + try: + out = self.output_ids.tolist()[0] # should be the only synch we do + except Exception: + out = [0, 1] + else: + out = [0, 0] + return out @traced def _maybe_send_output(self, state: RequestState, token: int): diff --git a/tests/fixtures/gpt_oss/integration_tests.json b/tests/fixtures/gpt_oss/integration_tests.json new file mode 100644 index 000000000000..99b19b0ee7e0 --- /dev/null +++ b/tests/fixtures/gpt_oss/integration_tests.json @@ -0,0 +1,346 @@ +[ + { + "quantized": true, + "model": "120b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here", + "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here", + "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": false, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here", + "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": false, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you too!\n\nSure! Here", + "How are you? Tell me the name of the president of the United\n\nHello! As of my last update in November 2023, the President of the" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": true, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "120b", + "kernels": true, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're looking for", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": false, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": false, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too.\n\nIt sounds like you're expressing a", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": true, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": true, + "model": "20b", + "kernels": true, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + "Did not work" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a", + "How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model", + "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a", + "How are you? Tell me the name of the president of the United Kingdom?\n\nThe United Kingdom does not have a president. The head of state is the" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model trained by OpenAI.\n\nI am a large language model", + "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": false, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a", + "How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions," + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": false, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure", + "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here" + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": true, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue,\nI am a language model, not a human being.\n```\n\nThis poem is a", + "How are you? Tell me the name of the president of the United States?\n\nAs an AI language model, I do not have personal feelings or emotions," + ] + }, + { + "quantized": false, + "model": "120b", + "kernels": true, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue, I am a language model, and I can help you with your request.\n\nSure", + "How are you? Tell me the name of the president of the United\n\nHello! I'm an AI language model, so I don't have feelings, but I'm here" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": false, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a", + "How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": true, + "attn_impl": "ft-hf-o-c/vllm-flash-attn3", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue\" (makes sense). But the phrase \"the answer is 3\" is not a", + "How are you? Tell me the name of the president of the United States.\" The answer to that is \"Joe Biden.\" The user is asking for the name" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": false, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": false, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is", + "How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": true, + "attn_impl": "eager", + "mode": "eval", + "outputs": [ + ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", + "How are you? Tell me the name of the president of the United States.\" The assistant should respond with the name of the president. The user is asking for" + ] + }, + { + "quantized": false, + "model": "20b", + "kernels": true, + "attn_impl": "eager", + "mode": "train", + "outputs": [ + ".....Roses are red, violets are blue.\" -> from which we can derive a rule: if we have a red object that is", + "How are you? Tell me the name of the president of the United States.\n\nI am an AI language model and I do not have a personal life or" + ] + } +] From afffd5819c45b5cd473dcd444cd99e1ef680a39f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 08:56:34 +0000 Subject: [PATCH 305/342] update integration tests --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 180 +++++++++--------- 1 file changed, 93 insertions(+), 87 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index c6df33b15046..ef54883e5e9b 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -14,6 +14,9 @@ """Testing suite for the PyTorch GptOss model.""" import unittest +import json +import os + import pytest from parameterized import parameterized @@ -84,9 +87,7 @@ class GptOssModelTest(CausalLMModelTest, unittest.TestCase): def setUp(self): self.model_tester = GptOssModelTester(self) - self.config_tester = ConfigTester( - self, config_class=GptOssConfig, hidden_size=37 - ) + self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37) @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): @@ -102,22 +103,16 @@ def test_eager_matches_sdpa_generate(self): @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate - @unittest.skip( - "GptOss has HybridCache which is not compatible with assisted decoding" - ) + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_matches_greedy_search(self, assistant_type): pass - @unittest.skip( - "GptOss has HybridCache which is not compatible with assisted decoding" - ) + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): pass @pytest.mark.generate - @unittest.skip( - "GptOss has HybridCache which is not compatible with assisted decoding" - ) + @unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding") def test_assisted_decoding_sample(self): pass @@ -141,21 +136,15 @@ def test_contrastive_generate_dict_outputs_use_cache(self): def test_contrastive_generate_low_memory(self): pass - @unittest.skip( - "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_with_static_cache(self): pass - @unittest.skip( - "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip( - "GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_continue_from_inputs_embeds(self): pass @@ -166,9 +155,7 @@ def test_generate_continue_from_inputs_embeds(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip( - "GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together." - ) + @unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.") def test_eager_matches_fa2_generate(self): pass @@ -177,6 +164,11 @@ def test_flash_attn_2_equivalence(self): pass +RESULTS_PATH = os.path.join( + os.path.dirname(__file__).split("transformers")[0], "tests/fixtures/gpt_oss/integration_tests.json" +) + + @slow @require_torch_accelerator class GptOssIntegrationTest(unittest.TestCase): @@ -192,85 +184,99 @@ def tearDown(self): cleanup(torch_device, gc_collect=True) @staticmethod - def load_and_forward( - model_id, attn_implementation, input_text, **pretrained_kwargs - ): + def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs): if not isinstance(attn_implementation, list): attn_implementation = [attn_implementation] - text = [] - model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs - ).to(torch_device) + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs).to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + outputs = [] for attn in attn_implementation: model.set_attn_implementation(attn) - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(input_text, return_tensors="pt", padding=True).to( - torch_device - ) - + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) output = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_text = tokenizer.batch_decode(output, skip_special_tokens=False) - text += [output_text] - return text + outputs.append(output_text) + return outputs @require_read_token - def test_model_20b_bf16(self): - model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" - EXPECTED_TEXTS = [ - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", - "Hi today I'm going to be talking about the history of the United States. The United States of America", + @parameterized.expand( + [ + # (quantized, model, kernels, attn_impl, mode) + (False, "120b", False, "eager", "eval"), + (False, "120b", False, "eager", "train"), + (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "120b", True, "eager", "eval"), + (False, "120b", True, "eager", "train"), + (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "120b", False, "eager", "eval"), + (True, "120b", False, "eager", "train"), + (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "120b", True, "eager", "eval"), + (True, "120b", True, "eager", "train"), + (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "20b", False, "eager", "eval"), + (False, "20b", False, "eager", "train"), + (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "20b", True, "eager", "eval"), + (False, "20b", True, "eager", "train"), + (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "20b", False, "eager", "eval"), + (True, "20b", False, "eager", "train"), + (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "20b", True, "eager", "eval"), + (True, "20b", True, "eager", "train"), + (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), ] + ) + def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): + model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs" output_text = self.load_and_forward( model_id, - [ - "eager", - "ft-hf-o-c/vllm-flash-attn3", - ], + attn_impl, self.input_text, + use_kernels=kernels, ) - self.assertEqual(output_text[0], EXPECTED_TEXTS) - self.assertEqual(output_text[1], EXPECTED_TEXTS) - @require_read_token - def test_model_20b_bf16_use_kernels(self): - model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" - EXPECTED_TEXTS = [ - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", - "Hi today I'm going to be talking about the history of the United States. The United States of America", - ] - output_text = self.load_and_forward( - model_id, - [ - "eager", - "ft-hf-o-c/vllm-flash-attn3", - ], - self.input_text, - use_kenels=True, - ) - self.assertEqual(output_text[0], EXPECTED_TEXTS) - self.assertEqual(output_text[1], EXPECTED_TEXTS) + # Flatten outputs if needed (since we loop over attn_impl) + if isinstance(output_text[0], list): + output_text = output_text[0] + + result_entry = { + "quantized": quantized, + "model": model, + "kernels": kernels, + "attn_impl": attn_impl, + "mode": mode, + "outputs": output_text, + } - @require_read_token - def test_model_120b_bf16_use_kernels(self): - model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs" - EXPECTED_TEXTS = [ - ".....Roses are red, violets are blue, I love you, and I love you too!\n\nRoses are red, vio", - """How are you? Tell me the -name of the president of the United States." The assistant should respond with the name of the president. The user is aski -ng for""", - ] - output_text = self.load_and_forward( - model_id, - [ - "eager", - "ft-hf-o-c/vllm-flash-attn3", - ], - self.input_text, - use_kenels=True, - ) - self.assertEqual(output_text[0], EXPECTED_TEXTS) - self.assertEqual(output_text[1], EXPECTED_TEXTS) + # Append to result.json for comparison + if os.path.exists(RESULTS_PATH): + with open(RESULTS_PATH, "r") as f: + results = json.load(f) + else: + results = [] + + results.append(result_entry) + + with open(RESULTS_PATH, "w") as f: + json.dump(results, f, indent=2) + + # Optionally, assert that at least output shape is correct + self.assertIsInstance(output_text, list) + self.assertTrue(all(isinstance(x, str) for x in output_text)) @slow From 00d6703c3bb1080ac8f7d18de2667ff67b441142 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Mon, 4 Aug 2025 09:04:26 +0000 Subject: [PATCH 306/342] initial fix --- src/transformers/modeling_utils.py | 6 +++++- src/transformers/quantizers/quantizer_mxfp4.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ee8d0faa6690..bf8d95e66e36 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5012,7 +5012,7 @@ def from_pretrained( if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config + model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config, use_kernels=use_kernels ) # We store the original dtype for quantized models as we cannot easily retrieve it # once the weights have been quantized @@ -5078,6 +5078,10 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: + + if not is_kernels_available(): + raise ValueError("Kernels are not available. To use kernels, please install kernels using `pip install kernels`") + from kernels import Device, kernelize kernelize(model, device=Device(type=model.device.type)) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 1358f3394da0..ffd1308583da 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -258,6 +258,13 @@ def _process_model_before_weight_loading( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) + use_kernels = kwargs.get("use_kernels", False) + # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling + if use_kernels: + logger.warning_once("You are using full precision kernels, we will dequantize the model to bf16. " + "To use the quantized model with quantization kernels, please set use_kernels=False") + self.quantization_config.dequantize = True + config = model.config model = replace_with_mxfp4_linear( model, From 6a8710ec7a46f747c3814ea3a8aea9c445b4c865 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Aug 2025 09:19:42 +0000 Subject: [PATCH 307/342] style and update tests --- src/transformers/dependency_versions_table.py | 4 +- .../generation/continuous_batching.py | 2 +- src/transformers/integrations/flash_paged.py | 10 +- src/transformers/integrations/mxfp4.py | 4 +- src/transformers/modeling_utils.py | 4 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 4 +- .../models/auto/tokenization_auto.py | 2 +- .../quantizers/quantizer_mxfp4.py | 24 +- tests/models/gpt_oss/test_modeling_gpt_oss.py | 265 +++++++++++------- 10 files changed, 200 insertions(+), 123 deletions(-) diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 7f6a723b2f11..758b4c590239 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -34,7 +34,7 @@ "kenlm": "kenlm", "keras": "keras>2.9,<2.16", "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", - "kernels": "kernels>=0.6.1,<0.7", + "kernels": "kernels>=0.6.1,<=0.9", "librosa": "librosa", "natten": "natten>=0.14.6,<0.15.0", "nltk": "nltk<=3.8.1", @@ -43,7 +43,7 @@ "onnxconverter-common": "onnxconverter-common", "onnxruntime-tools": "onnxruntime-tools>=1.4.2", "onnxruntime": "onnxruntime>=1.4.0", - "openai": "openai", + "openai": "openai>=1.98.0", "opencv-python": "opencv-python", "optimum-benchmark": "optimum-benchmark>=0.3.0", "optuna": "optuna", diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 7ab0554d296a..57e57c959c26 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -30,9 +30,9 @@ from ..configuration_utils import PretrainedConfig from ..generation.configuration_utils import GenerationConfig +from ..tokenization_utils_fast import PreTrainedTokenizerFast from ..utils.logging import logging from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced -from ..tokenization_utils_fast import PreTrainedTokenizerFast class RequestStatus(Enum): diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 31e329860cfa..096e3fdf9522 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -48,15 +48,9 @@ def paged_attention_forward( window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. """ - k, v = cache.update( - k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs - ) + k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) - sliding_window = ( - (-1, -1) - if not getattr(module, "sliding_window", False) - else (module.sliding_window, 0) - ) + sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0) if implementation is not None: flash_attn_varlen_func = implementation.flash_attn_varlen_func custom_kwargs = {"s_aux": kwargs.get("s_aux")} diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 56ffee437b12..86517671b5f3 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -333,9 +333,7 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** delattr(module, scales_attr) -def load_and_swizzle_mxfp4( - module, param_name, param_value, target_device, **kwargs -): +def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs): from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig from ..integrations.tensor_parallel import shard_and_distribute_module diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ee8d0faa6690..eea43131e91b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -897,7 +897,9 @@ def _load_state_dict_into_meta_model( if is_fsdp_enabled() and not is_local_dist_rank_0(): param_to = "meta" val_kwargs = {} - if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (value.dtype == torch.uint8 or value.dtype == torch.int8): + if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or ( + value.dtype == torch.uint8 or value.dtype == torch.int8 + ): val_kwargs["requires_grad"] = False value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) setattr(module, param_type, value) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0542b084e101..15d10c756618 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -172,6 +172,7 @@ ("gpt_neo", "GPTNeoConfig"), ("gpt_neox", "GPTNeoXConfig"), ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"), + ("gpt_oss", "GptOssConfig"), ("gptj", "GPTJConfig"), ("gptsan-japanese", "GPTSanJapaneseConfig"), ("granite", "GraniteConfig"), @@ -274,7 +275,6 @@ ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), ("openai-gpt", "OpenAIGPTConfig"), - ("gpt_oss", "GptOssConfig"), ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), @@ -578,6 +578,7 @@ ("gpt_neo", "GPT Neo"), ("gpt_neox", "GPT NeoX"), ("gpt_neox_japanese", "GPT NeoX Japanese"), + ("gpt_oss", "GptOss"), ("gptj", "GPT-J"), ("gptsan-japanese", "GPTSAN-japanese"), ("granite", "Granite"), @@ -692,7 +693,6 @@ ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), ("openai-gpt", "OpenAI GPT"), - ("gpt_oss", "GptOss"), ("opt", "OPT"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7713fd6c41df..6a4dc33b5b61 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -163,6 +163,7 @@ ("gpt_neo", "GPTNeoModel"), ("gpt_neox", "GPTNeoXModel"), ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), + ("gpt_oss", "GptOssModel"), ("gptj", "GPTJModel"), ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), ("granite", "GraniteModel"), @@ -262,7 +263,6 @@ ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), ("openai-gpt", "OpenAIGPTModel"), - ("gpt_oss", "GptOssModel"), ("opt", "OPTModel"), ("owlv2", "Owlv2Model"), ("owlvit", "OwlViTModel"), @@ -632,6 +632,7 @@ ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), + ("gpt_oss", "GptOssForCausalLM"), ("gptj", "GPTJForCausalLM"), ("granite", "GraniteForCausalLM"), ("granitemoe", "GraniteMoeForCausalLM"), @@ -666,7 +667,6 @@ ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), - ("gpt_oss", "GptOssForCausalLM"), ("opt", "OPTForCausalLM"), ("pegasus", "PegasusForCausalLM"), ("persimmon", "PersimmonForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 381fd15cfa31..232221782f78 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -300,6 +300,7 @@ ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), + ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)), ("granite", ("GPT2Tokenizer", None)), @@ -484,7 +485,6 @@ "openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None), ), - ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 1358f3394da0..fd221f2d6301 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -170,7 +170,9 @@ def create_quantized_param( weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()) ) module.gate_up_proj = triton_weight_tensor - module.gate_up_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) + module.gate_up_proj_blocks = torch.nn.Parameter( + triton_weight_tensor.storage.data, requires_grad=False + ) elif "down_proj" in param_name: right_pad = module.down_proj_right_pad bottom_pad = module.down_proj_bottom_pad @@ -178,9 +180,13 @@ def create_quantized_param( param_value, (0, right_pad, 0, bottom_pad, 0, 0), mode="constant", value=0 ).to(target_device) triton_weight_tensor, weight_scale = quantize_to_mxfp4(loaded_weight) - module.down_proj_precision_config = PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())) + module.down_proj_precision_config = PrecisionConfig( + weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()) + ) module.down_proj = triton_weight_tensor - module.down_proj_blocks = torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False) + module.down_proj_blocks = torch.nn.Parameter( + triton_weight_tensor.storage.data, requires_grad=False + ) # we take this path if already quantized but not in a compatible way # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales @@ -288,10 +294,10 @@ def update_tp_plan(self, config): if getattr(config, "base_model_tp_plan", None) is not None: config.base_model_tp_plan.update( { - "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", - "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", - "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm", + "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm", + "layers.*.mlp.experts.down_proj_scales": "grouped_gemm", } ) return config @@ -310,5 +316,7 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self) -> bool: - logger.warning_once("MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()") + logger.warning_once( + "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()" + ) return False diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index ef54883e5e9b..23954228835c 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -13,15 +13,17 @@ # limitations under the License. """Testing suite for the PyTorch GptOss model.""" -import unittest +import inspect import json import os - +import subprocess +import sys +import tempfile +import unittest import pytest from parameterized import parameterized -from tests.tensor_parallel.test_tensor_parallel import TensorParallelTestBase from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -33,7 +35,6 @@ require_read_token, require_torch, require_torch_accelerator, - require_torch_multi_accelerator, slow, torch_device, ) @@ -165,10 +166,68 @@ def test_flash_attn_2_equivalence(self): RESULTS_PATH = os.path.join( - os.path.dirname(__file__).split("transformers")[0], "tests/fixtures/gpt_oss/integration_tests.json" + os.path.dirname(__file__).split("transformers")[0], + "tests/fixtures/gpt_oss/integration_tests.json", ) +# ------------------------ +# Worker function for distributed torchrun +# ------------------------ +def distributed_worker(quantized, model, kernels, attn_impl, mode): + """This is the function that will be executed by torchrun workers.""" + import os + + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.testing_utils import torch_device + + input_text = [ + "Roses are red, violets", + "How are you? Tell me the name of the president of", + ] + + # Convert args + quantized = quantized.lower() == "true" + kernels = kernels.lower() == "true" + + # Distributed model loading + model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs" + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="auto", + tp_plan="auto", # distributed inference + use_kernels=kernels, + ).to(torch_device) + model.set_attn_implementation(attn_impl) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # Inference + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_texts = tokenizer.batch_decode(output, skip_special_tokens=False) + + # Only rank 0 writes results + if int(os.environ.get("RANK", "0")) == 0: + result_entry = { + "quantized": quantized, + "model": model, + "kernels": kernels, + "attn_impl": attn_impl, + "mode": mode, + "outputs": output_texts, + } + + if os.path.exists(RESULTS_PATH): + with open(RESULTS_PATH, "r") as f: + results = json.load(f) + else: + results = [] + results.append(result_entry) + + with open(RESULTS_PATH, "w") as f: + json.dump(results, f, indent=2) + + @slow @require_torch_accelerator class GptOssIntegrationTest(unittest.TestCase): @@ -183,125 +242,141 @@ def setUp(self): def tearDown(self): cleanup(torch_device, gc_collect=True) + # ------------------------ + # Non-distributed inference + # ------------------------ @staticmethod def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwargs): - if not isinstance(attn_implementation, list): - attn_implementation = [attn_implementation] - - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, **pretrained_kwargs).to( - torch_device - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation=attn_implementation, + **pretrained_kwargs, + ).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(model_id) - outputs = [] - for attn in attn_implementation: - model.set_attn_implementation(attn) - inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) - output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_text = tokenizer.batch_decode(output, skip_special_tokens=False) - outputs.append(output_text) - return outputs + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + return output_text - @require_read_token - @parameterized.expand( - [ - # (quantized, model, kernels, attn_impl, mode) - (False, "120b", False, "eager", "eval"), - (False, "120b", False, "eager", "train"), - (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), - (False, "120b", True, "eager", "eval"), - (False, "120b", True, "eager", "train"), - (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), - (True, "120b", False, "eager", "eval"), - (True, "120b", False, "eager", "train"), - (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), - (True, "120b", True, "eager", "eval"), - (True, "120b", True, "eager", "train"), - (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), - (False, "20b", False, "eager", "eval"), - (False, "20b", False, "eager", "train"), - (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), - (False, "20b", True, "eager", "eval"), - (False, "20b", True, "eager", "train"), - (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), - (True, "20b", False, "eager", "eval"), - (True, "20b", False, "eager", "train"), - (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), - (True, "20b", True, "eager", "eval"), - (True, "20b", True, "eager", "train"), - (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), - (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + # ------------------------ + # Distributed inference using inspect + # ------------------------ + @staticmethod + def run_distributed_test(quantized, model, kernels, attn_impl, mode): + """Launch torchrun using a temporary worker file generated from inspect.getsource().""" + import textwrap + + # Extract worker function source dynamically + worker_src = inspect.getsource(distributed_worker) + + # Create a temp file that calls the worker + script_code = f""" +import sys +{worker_src} + +if __name__ == "__main__": + distributed_worker("{quantized}", "{model}", "{kernels}", "{attn_impl}", "{mode}") +""" + # Dedent for proper formatting + script_code = textwrap.dedent(script_code) + + # Write to temp file + with tempfile.NamedTemporaryFile("w", suffix="_worker.py", delete=False) as tmp: + tmp.write(script_code) + tmp_path = tmp.name + + # Launch torchrun + cmd = [ + "torchrun", + "--nproc_per_node=8", + sys.executable, + tmp_path, ] - ) + subprocess.run(cmd, check=True) + + # Cleanup + os.remove(tmp_path) + + # ------------------------ + # Shared parameterization + # ------------------------ + PARAMETERS = [ + (False, "120b", False, "eager", "eval"), + (False, "120b", False, "eager", "train"), + (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "120b", True, "eager", "eval"), + (False, "120b", True, "eager", "train"), + (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "120b", False, "eager", "eval"), + (True, "120b", False, "eager", "train"), + (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "120b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "120b", True, "eager", "eval"), + (True, "120b", True, "eager", "train"), + (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "120b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "20b", False, "eager", "eval"), + (False, "20b", False, "eager", "train"), + (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (False, "20b", True, "eager", "eval"), + (False, "20b", True, "eager", "train"), + (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (False, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "20b", False, "eager", "eval"), + (True, "20b", False, "eager", "train"), + (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "20b", False, "ft-hf-o-c/vllm-flash-attn3", "train"), + (True, "20b", True, "eager", "eval"), + (True, "20b", True, "eager", "train"), + (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "eval"), + (True, "20b", True, "ft-hf-o-c/vllm-flash-attn3", "train"), + ] + + # ------------------------ + # Non-distributed test + # ------------------------ + @require_read_token + @parameterized.expand(PARAMETERS) def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs" - output_text = self.load_and_forward( + output_texts = self.load_and_forward( model_id, attn_impl, self.input_text, use_kernels=kernels, ) - # Flatten outputs if needed (since we loop over attn_impl) - if isinstance(output_text[0], list): - output_text = output_text[0] - result_entry = { "quantized": quantized, "model": model, "kernels": kernels, "attn_impl": attn_impl, "mode": mode, - "outputs": output_text, + "outputs": output_texts, } - # Append to result.json for comparison if os.path.exists(RESULTS_PATH): with open(RESULTS_PATH, "r") as f: results = json.load(f) else: results = [] - results.append(result_entry) - with open(RESULTS_PATH, "w") as f: json.dump(results, f, indent=2) - # Optionally, assert that at least output shape is correct - self.assertIsInstance(output_text, list) - self.assertTrue(all(isinstance(x, str) for x in output_text)) - - -@slow -@require_torch_multi_accelerator -class GptOssTPTest(TensorParallelTestBase): - def test_model_training(self): - self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/gpt-oss-20b-trfs", - mode="training", - expected_output="you with something?", - ) - self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/gpt-oss-120b-trfs", - mode="training", - expected_output="you with something?", - ) + self.assertIsInstance(output_texts, list) + self.assertTrue(all(isinstance(x, str) for x in output_texts)) - def test_model_generate(self): - self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/gpt-oss-20b-trfs", - mode="generate", - expected_output="with something", - ) - self.run_tensor_parallel_test( - model_id="/fsx/vb/new-oai/20b-converted-quantized", - mode="generate", - expected_output="with something", - ) + # ------------------------ + # Distributed test + # ------------------------ + @require_read_token + @parameterized.expand(PARAMETERS) + def test_model_outputs_distributed(self, quantized, model, kernels, attn_impl, mode): + self.run_distributed_test(quantized, model, kernels, attn_impl, mode) From 4cb0a93aef6004a6286d802ba364cbbbaae64564 Mon Sep 17 00:00:00 2001 From: MekkCyber Date: Mon, 4 Aug 2025 09:50:50 +0000 Subject: [PATCH 308/342] fix --- src/transformers/quantizers/auto.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 9d52dacd5fd8..49051f442695 100644 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -227,9 +227,11 @@ def merge_quantization_configs( warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." - if warning_msg != "": + if warning_msg != "" and not isinstance(quantization_config, Mxfp4Config): warnings.warn(warning_msg) - + else: + # in the case of mxfp4, we don't want to print the warning message, bit confusing for users + logger.info(warning_msg) return quantization_config @staticmethod From b9f34dd6f9065d3441fb4513f30d51b04ce5acf3 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Mon, 4 Aug 2025 13:04:42 +0200 Subject: [PATCH 309/342] chore(gpt oss): remove mlp_bias from configuration It was just a leftover. --- src/transformers/models/gpt_oss/configuration_gpt_oss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/configuration_gpt_oss.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py index 9a8668e5798d..6d440b864294 100644 --- a/src/transformers/models/gpt_oss/configuration_gpt_oss.py +++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py @@ -105,7 +105,6 @@ def __init__( rope_config_validation(self) self.attention_bias = True - self.mlp_bias = False self.max_position_embeddings = max_position_embeddings self.router_aux_loss_coef = router_aux_loss_coef self.output_router_logits = output_router_logits From eb942a6b0c331669b724d680f17eddb3308f1837 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 4 Aug 2025 11:44:19 +0000 Subject: [PATCH 310/342] stats --- src/transformers/modeling_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index aeffe0055b14..7e9e4a42a065 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4893,10 +4893,9 @@ def from_pretrained( config = hf_quantizer.update_tp_plan(config) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` - if hasattr(hf_quantizer.quantization_config.quant_method, "value"): - user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value - else: - user_agent["quant"] = hf_quantizer.quantization_config.quant_method + if not getattr(hf_quantizer.quantization_config, "dequantize", False): + quant_method = hf_quantizer.quantization_config.quant_method + user_agent["quant"] = getattr(quant_method, "value", quant_method) if gguf_file is not None and hf_quantizer is not None: raise ValueError( From 94a85f0ae0ba2c0c2e2824717a7527214c94a2f6 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 13:55:08 +0200 Subject: [PATCH 311/342] Integration tests --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 23954228835c..8aeb15a19c41 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -17,9 +17,9 @@ import json import os import subprocess -import sys import tempfile import unittest +from pathlib import Path import pytest from parameterized import parameterized @@ -51,6 +51,8 @@ GptOssModel, ) + NUM_GPUS = torch.cuda.device_count() + class GptOssModelTester(CausalLMModelTester): if is_torch_available(): @@ -165,16 +167,13 @@ def test_flash_attn_2_equivalence(self): pass -RESULTS_PATH = os.path.join( - os.path.dirname(__file__).split("transformers")[0], - "tests/fixtures/gpt_oss/integration_tests.json", -) +RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" # ------------------------ # Worker function for distributed torchrun # ------------------------ -def distributed_worker(quantized, model, kernels, attn_impl, mode): +def distributed_worker(quantized, model_size, kernels, attn_impl, mode): """This is the function that will be executed by torchrun workers.""" import os @@ -191,7 +190,7 @@ def distributed_worker(quantized, model, kernels, attn_impl, mode): kernels = kernels.lower() == "true" # Distributed model loading - model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs" + model_id = f"/fsx/vb/new-oai/gpt-oss-{model_size}-trfs" model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype="auto", @@ -210,7 +209,7 @@ def distributed_worker(quantized, model, kernels, attn_impl, mode): if int(os.environ.get("RANK", "0")) == 0: result_entry = { "quantized": quantized, - "model": model, + "model": model_size, "kernels": kernels, "attn_impl": attn_impl, "mode": mode, @@ -275,6 +274,10 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): # Create a temp file that calls the worker script_code = f""" import sys +import json + +RESULTS_PATH = {RESULTS_PATH} + {worker_src} if __name__ == "__main__": @@ -291,8 +294,7 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): # Launch torchrun cmd = [ "torchrun", - "--nproc_per_node=8", - sys.executable, + f"--nproc_per_node={NUM_GPUS}", tmp_path, ] subprocess.run(cmd, check=True) @@ -341,8 +343,8 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): # ------------------------ # Non-distributed test # ------------------------ - @require_read_token @parameterized.expand(PARAMETERS) + @require_read_token def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): model_id = f"/fsx/vb/new-oai/gpt-oss-{model}-trfs" output_texts = self.load_and_forward( @@ -376,7 +378,7 @@ def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): # ------------------------ # Distributed test # ------------------------ - @require_read_token @parameterized.expand(PARAMETERS) + @require_read_token def test_model_outputs_distributed(self, quantized, model, kernels, attn_impl, mode): self.run_distributed_test(quantized, model, kernels, attn_impl, mode) From 210067a3768e2c1c230f9156b1c067c5bb39b2ed Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 14:02:15 +0200 Subject: [PATCH 312/342] whoops --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 8aeb15a19c41..2e529aaf36ad 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -276,7 +276,7 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): import sys import json -RESULTS_PATH = {RESULTS_PATH} +RESULTS_PATH = "{RESULTS_PATH}" {worker_src} From e60807a7d7f66f458a0be5191331709293a9d247 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 14:05:37 +0200 Subject: [PATCH 313/342] Shouldn't move model --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 2e529aaf36ad..a0d9fa9b2b67 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -252,10 +252,10 @@ def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwa device_map="auto", attn_implementation=attn_implementation, **pretrained_kwargs, - ).to(torch_device) + ) tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) + inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device) output = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_text = tokenizer.batch_decode(output, skip_special_tokens=False) return output_text From c954ef7d68e014bac8593ac84d8a6916f9d0f884 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 4 Aug 2025 14:23:25 +0100 Subject: [PATCH 314/342] Ensure assistant messages without thinking always go to "final" channel --- .../gpt_oss/convert_gpt_oss_weights_to_hf.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 9205f7364303..370066186995 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -687,27 +687,20 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- tool_call.arguments|tojson }} {{- "<|end|>" }} {%- set last_tool_call.name = tool_call.name %} - {%- elif "thinking" in message and loop.last and not add_generation_prompt %} + {%- elif loop.last and not add_generation_prompt %} {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} {#- This is a situation that should only occur in training, never in inference. #} - {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {%- if "thinking" in message %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {%- endif %} {#- <|return|> indicates the end of generation, but <|end|> does not #} {#- <|return|> should never be an input to the model, but we include it as the final token #} {#- when training, so the model learns to emit it. #} {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }} - {%- set last_tool_call.name = none %} - {%- elif "thinking" in message %} + {%- else %} {#- CoT is dropped during all previous turns, so we never render it for inference #} {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} {%- set last_tool_call.name = none %} - {%- elif loop.last and not add_generation_prompt %} - {#- <|return|> indicates the end of generation, but <|end|> does not #} - {#- <|return|> should never be an input to the model, but we include it as the final token #} - {#- when training, so the model learns to emit it. #} - {{- "<|start|>assistant<|message|>" + message.content + "<|return|>" }} - {%- else %} - {{- "<|start|>assistant<|message|>" + message.content + "<|end|>" }} - {%- set last_tool_call.name = none %} {%- endif %} {%- elif message.role == 'tool' -%} {%- if last_tool_call.name is none %} From 13f675678cfc398da9e966716cc312d7befb641f Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 4 Aug 2025 14:42:32 +0100 Subject: [PATCH 315/342] More checks to ensure expected format --- .../gpt_oss/convert_gpt_oss_weights_to_hf.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 370066186995..5950fcc9279c 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -672,6 +672,17 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- for message in loop_messages -%} {#- At this point only assistant/user/tool messages should remain #} {%- if message.role == 'assistant' -%} + {#- Checks to ensure the messages are being passed in the format we expect #} + {%- if "content" in message %} + {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|> and <|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} + {%- if "thinking" in message %} + {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the trhinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|> and <|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} {%- if "tool_calls" in message %} {#- We assume max 1 tool call per message, and so we infer the tool call name #} {#- in "tool" messages from the most recent assistant tool call name #} @@ -679,8 +690,12 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- if tool_call.function %} {%- set tool_call = tool_call.function %} {%- endif %} - {%- if message.content %} + {%- if message.content and message.thinking %} + {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }} + {%- elif message.content %} {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }} + {%- elif message.thinking %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} {%- endif %} {{- "<|start|>assistant to=" }} {{- "functions." + tool_call.name + "<|channel|>commentary json<|message|>" }} From bee0515d059e4e281f85a7a40e4cadecf5f689d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 4 Aug 2025 07:35:06 -0700 Subject: [PATCH 316/342] Add pad_token_id to model configuration in write_model function (#51) --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 9205f7364303..3e0d46b22286 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -149,6 +149,7 @@ def write_model( ): os.makedirs(model_path, exist_ok=True) eos_token_id = 199999 if not instruct else 200002 + pad_token_id = 199999 original_config = json.loads((Path(input_base_path) / "config.json").read_text()) @@ -163,7 +164,11 @@ def write_model( } config = GptOssConfig( - num_local_experts=num_local_experts, rope_scaling=rope_scaling, eos_token_id=eos_token_id, **original_config + num_local_experts=num_local_experts, + rope_scaling=rope_scaling, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **original_config, ) print(f"Fetching all parameters from the checkpoint at {input_base_path}...") From e1f46b45757a6973292b715b4b1ca9a3c8030e80 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 4 Aug 2025 16:40:16 +0200 Subject: [PATCH 317/342] Add oai fix fast tests (#59) * Fix some fast tests * Force some updates * Remove unnecessary fixes --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 4259543f8f84..a12944b5f77b 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -396,6 +396,8 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: @@ -409,6 +411,9 @@ def _init_weights(self, module): module.down_proj_bias.data.zero_() elif isinstance(module, GptOssAttention): module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) @auto_docstring From e29f6590fd48f0d9d852811996823c47effe9103 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 4 Aug 2025 17:09:47 +0100 Subject: [PATCH 318/342] Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 5950fcc9279c..61ce574e060f 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -675,7 +675,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {#- Checks to ensure the messages are being passed in the format we expect #} {%- if "content" in message %} {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %} - {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|> and <|end|>') in the 'content' field.") }} + {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} {%- endif %} {%- endif %} {%- if "thinking" in message %} From 5c6255ec60be1522814a26dafca0ef4678c78a5e Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 4 Aug 2025 17:10:18 +0100 Subject: [PATCH 319/342] Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 61ce574e060f..95a55a584748 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -680,7 +680,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- endif %} {%- if "thinking" in message %} {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %} - {{- raise_exception("You have passed a message containing <|channel|> tags in the trhinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|> and <|end|>') in the 'content' field.") }} + {{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} {%- endif %} {%- endif %} {%- if "tool_calls" in message %} From 889fe011eadf00de3242b8f2c86156dce3e628c7 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 4 Aug 2025 17:10:46 +0100 Subject: [PATCH 320/342] Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 95a55a584748..54bc510c4696 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -675,7 +675,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {#- Checks to ensure the messages are being passed in the format we expect #} {%- if "content" in message %} {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %} - {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|> and <|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} + {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} {%- endif %} {%- endif %} {%- if "thinking" in message %} From 9844308a07d796797b17c9fb892bddd0dd93b66c Mon Sep 17 00:00:00 2001 From: Vaibhavs10 Date: Mon, 4 Aug 2025 19:01:23 +0200 Subject: [PATCH 321/342] reasoning -> Reasoning --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 254ae6f13821..9419661cb3b8 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -623,7 +623,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- if reasoning_effort is not defined %} {%- set reasoning_effort = "medium" %} {%- endif %} - {{- "reasoning: " + reasoning_effort + "\n\n" }} + {{- "Reasoning: " + reasoning_effort + "\n\n" }} {%- if builtin_tools %} {{- "# Tools\n\n" }} {%- set available_builtin_tools = namespace(browser=false, python=false) %} From b222c6ffae99f6d646847d0147777a0f55ad691c Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 23:50:00 +0200 Subject: [PATCH 322/342] Add additional integration tests --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 122 +++++++++++++++++- 1 file changed, 120 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index a0d9fa9b2b67..3ac67f04d626 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -34,7 +34,6 @@ cleanup, require_read_token, require_torch, - require_torch_accelerator, slow, torch_device, ) @@ -228,7 +227,7 @@ def distributed_worker(quantized, model_size, kernels, attn_impl, mode): @slow -@require_torch_accelerator +# @require_torch_accelerator class GptOssIntegrationTest(unittest.TestCase): input_text = [ "Roses are red, violets", @@ -382,3 +381,122 @@ def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): @require_read_token def test_model_outputs_distributed(self, quantized, model, kernels, attn_impl, mode): self.run_distributed_test(quantized, model, kernels, attn_impl, mode) + + def test_model_matches_original_20b(self): + input_text = "Roses are red, violets" + + original_output = "Roses are red, violets are blue, I love you, and I love you too." + original_logprobs = [ + -0.037353515625, + -0.08154296875, + -1.21875, + -1.953125, + -2.234375, + -0.96875, + -1.546875, + -1.640625, + -0.93359375, + -1.609375, + -1.625, + -0.85546875, + -1.7265625, + -0.7421875, + -2.078125, + -0.006561279296875, + -0.10498046875, + -0.1767578125, + -0.1240234375, + -0.099609375, + ] + + model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" + + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="eager", + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer(input_text)["input_ids"] + + num_generated_tokens = 0 + with torch.no_grad(): + for i in range(12): + tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0) + logits = model(tensors).logits[0] + + predicted_token = torch.argmax(logits[-1, :], dim=-1).item() + logprobs = torch.log_softmax(logits[-1, :], dim=-1) + selected_logprobs = logprobs[predicted_token] + + tokens.append(predicted_token) + num_generated_tokens += 1 + decoded_token = tokenizer.decode([predicted_token]) + print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") + + # Logprob differences are not enforced as of now due to slightly different implementations + logprob_differences = selected_logprobs - original_logprobs[i] + + decoded_string = tokenizer.decode(tokens) + self.assertTrue(original_output.startswith(decoded_string)) + + def test_model_matches_original_120b(self): + input_text = "Roses are red, violets" + + original_output = """Roses are red, violets are blue + I am a language model, not a human being""" + original_logprobs = [ + -0.90234375, + -0.66015625, + -1.546875, + -2.703125, + -2.078125, + -1.21875, + -2.484375, + -0.031982421875, + -0.84765625, + -1.890625, + -0.1923828125, + -2.046875, + -1.65625, + -1.3515625, + -1.1640625, + -0.3671875, + -1.9921875, + -1.5390625, + -1.46875, + -0.85546875, + ] + + model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs" + + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="eager", + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer(input_text)["input_ids"] + + num_generated_tokens = 0 + with torch.no_grad(): + for i in range(12): + tensors = torch.as_tensor(tokens, dtype=torch.int32, device=model.device).unsqueeze(0) + logits = model(tensors).logits[0] + + predicted_token = torch.argmax(logits[-1, :], dim=-1).item() + logprobs = torch.log_softmax(logits[-1, :], dim=-1) + selected_logprobs = logprobs[predicted_token] + + tokens.append(predicted_token) + num_generated_tokens += 1 + decoded_token = tokenizer.decode([predicted_token]) + print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") + + # Logprob differences are not enforced as of now due to slightly different implementations + logprob_differences = selected_logprobs - original_logprobs[i] + + decoded_string = tokenizer.decode(tokens) + self.assertTrue(original_output.startswith(decoded_string)) From 8421054243d06490956629a4843f57121fcfbae4 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 23:52:49 +0200 Subject: [PATCH 323/342] fixup --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 3ac67f04d626..454bf16992c6 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -34,6 +34,7 @@ cleanup, require_read_token, require_torch, + require_torch_accelerator, slow, torch_device, ) @@ -227,7 +228,7 @@ def distributed_worker(quantized, model_size, kernels, attn_impl, mode): @slow -# @require_torch_accelerator +@require_torch_accelerator class GptOssIntegrationTest(unittest.TestCase): input_text = [ "Roses are red, violets", From 6001771992b57bae1fda8dcb2739f9606f1347b0 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Aug 2025 23:53:43 +0200 Subject: [PATCH 324/342] Slight fixes --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 454bf16992c6..654ec1457d2b 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -445,8 +445,8 @@ def test_model_matches_original_20b(self): def test_model_matches_original_120b(self): input_text = "Roses are red, violets" - original_output = """Roses are red, violets are blue - I am a language model, not a human being""" + original_output = """Roses are red, violets are blue, +I am a language model, not a human being""" original_logprobs = [ -0.90234375, -0.66015625, From e360f17661bfe5a00c1924ad434976840c56a26c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 5 Aug 2025 00:53:45 +0000 Subject: [PATCH 325/342] align chat template with harmony --- .../gpt_oss/convert_gpt_oss_weights_to_hf.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 9419661cb3b8..fc23d892636f 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -536,12 +536,13 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- for tool in tools %} {%- set tool = tool.function %} {{- "// " + tool.description + "\n" }} - {{- "type "+ tool.name + " = (" }} - {%- if tool.parameters and tool.parameters.properties -%} - {{- "_: " }} - {{- "{\n" }} + {{- "type "+ tool.name + " = " }} + {%- if tool.parameters and tool.parameters.properties %} + {{- "(_: {\n" }} {%- for param_name, param_spec in tool.parameters.properties.items() %} - {{- "// " + param_spec.description + "\n" }} + {%- if param_spec.description %} + {{- "// " + param_spec.description + "\n" }} + {%- endif %} {{- param_name }} {%- if param_name not in (tool.parameters.required or []) -%} {{- "?" }} @@ -549,7 +550,9 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- ": " }} {{- render_typescript_type(param_spec, tool.parameters.required or []) }} {%- if param_spec.default is defined -%} - {%- if param_spec.oneOf %} + {%- if param_spec.enum %} + {{- ", // default: " + param_spec.default }} + {%- elif param_spec.oneOf %} {{- "// default: " + param_spec.default }} {%- else %} {{- ", // default: " + param_spec.default|tojson }} @@ -557,14 +560,16 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- endif -%} {%- if not loop.last %} {{- ",\n" }} + {%- else %} + {{- "\n" }} {%- endif -%} {%- endfor %} - {{- ",\n}) => any;\n" }} + {{- "}) => any;\n\n" }} {%- else -%} - {{- "\n}) => any;\n" }} + {{- "() => any;\n\n" }} {%- endif -%} {%- endfor %} - {{- "\n} // namespace " + namespace_name }} + {{- "} // namespace " + namespace_name }} {%- endmacro -%} {%- macro render_builtin_tools(browser_tool, python_tool) -%} @@ -636,8 +641,10 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {%- endfor %} {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }} {%- endif -%} - {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message.\n" }} - {{- "Calls to these tools must go to the commentary channel: 'functions'." }} + {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }} + {%- if tools -%} + {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }} + {%- endif -%} {%- endmacro -%} {#- Main Template Logic ================================================= #} @@ -703,9 +710,10 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} {%- endif %} {{- "<|start|>assistant to=" }} - {{- "functions." + tool_call.name + "<|channel|>commentary json<|message|>" }} + {{- "functions." + tool_call.name + "<|channel|>commentary " }} + {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }} {{- tool_call.arguments|tojson }} - {{- "<|end|>" }} + {{- "<|call|>" }} {%- set last_tool_call.name = tool_call.name %} {%- elif loop.last and not add_generation_prompt %} {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} @@ -727,7 +735,8 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} {%- endif %} {{- "<|start|>functions." + last_tool_call.name }} - {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} + {%- set js = message.content|tojson %} + {{- " to=assistant<|channel|>commentary<|message|>{ " + js[1:-1] + " }<|end|>" }} {%- elif message.role == 'user' -%} {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} {%- endif -%} From 5fe06b9ef1d59021c63cf7f0162213ac00ab2a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 5 Aug 2025 01:28:25 +0000 Subject: [PATCH 326/342] simplify --- .../models/gpt_oss/convert_gpt_oss_weights_to_hf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index fc23d892636f..34bcba3b2515 100644 --- a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -735,8 +735,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False): {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} {%- endif %} {{- "<|start|>functions." + last_tool_call.name }} - {%- set js = message.content|tojson %} - {{- " to=assistant<|channel|>commentary<|message|>{ " + js[1:-1] + " }<|end|>" }} + {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} {%- elif message.role == 'user' -%} {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} {%- endif -%} From ba792c9dc43418a633b11f67a04c85fd381815a8 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:08:44 +0200 Subject: [PATCH 327/342] Add comment --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index a12944b5f77b..355dd069eb83 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -244,7 +244,11 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) scores = probs[..., :-1] # we drop the sink here attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) From afc0fc490010cf5060e0c59b50287a78a29ddfe3 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:17:00 +0200 Subject: [PATCH 328/342] torch testing assert close --- src/transformers/modeling_utils.py | 11 ++++++++--- src/transformers/quantizers/quantizer_mxfp4.py | 6 ++++-- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 36aa8f3861aa..d3d8aa2610e0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5014,7 +5014,11 @@ def from_pretrained( if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config, use_kernels=use_kernels + model=model, + device_map=device_map, + keep_in_fp32_modules=model._keep_in_fp32_modules, + config=config, + use_kernels=use_kernels, ) # We store the original dtype for quantized models as we cannot easily retrieve it # once the weights have been quantized @@ -5080,9 +5084,10 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: - if not is_kernels_available(): - raise ValueError("Kernels are not available. To use kernels, please install kernels using `pip install kernels`") + raise ValueError( + "Kernels are not available. To use kernels, please install kernels using `pip install kernels`" + ) from kernels import Device, kernelize diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 4b00998d925c..061ca072f029 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -267,8 +267,10 @@ def _process_model_before_weight_loading( use_kernels = kwargs.get("use_kernels", False) # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling if use_kernels: - logger.warning_once("You are using full precision kernels, we will dequantize the model to bf16. " - "To use the quantized model with quantization kernels, please set use_kernels=False") + logger.warning_once( + "You are using full precision kernels, we will dequantize the model to bf16. " + "To use the quantized model with quantization kernels, please set use_kernels=False" + ) self.quantization_config.dequantize = True config = model.config diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 654ec1457d2b..073d902d7e05 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -436,8 +436,8 @@ def test_model_matches_original_20b(self): decoded_token = tokenizer.decode([predicted_token]) print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") - # Logprob differences are not enforced as of now due to slightly different implementations logprob_differences = selected_logprobs - original_logprobs[i] + torch.testing.assert_close(logprobs, logprob_differences, atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) @@ -496,8 +496,8 @@ def test_model_matches_original_120b(self): decoded_token = tokenizer.decode([predicted_token]) print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") - # Logprob differences are not enforced as of now due to slightly different implementations logprob_differences = selected_logprobs - original_logprobs[i] + torch.testing.assert_close(logprobs, logprob_differences, atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From 7bddb91b195957cb05fc9790facb0113e5b94157 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:19:42 +0200 Subject: [PATCH 329/342] torch testing assert close --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 073d902d7e05..3fef7bb86041 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -434,9 +434,9 @@ def test_model_matches_original_20b(self): tokens.append(predicted_token) num_generated_tokens += 1 decoded_token = tokenizer.decode([predicted_token]) - print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") - logprob_differences = selected_logprobs - original_logprobs[i] + + print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") torch.testing.assert_close(logprobs, logprob_differences, atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) @@ -494,10 +494,10 @@ def test_model_matches_original_120b(self): tokens.append(predicted_token) num_generated_tokens += 1 decoded_token = tokenizer.decode([predicted_token]) - print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}") - logprob_differences = selected_logprobs - original_logprobs[i] - torch.testing.assert_close(logprobs, logprob_differences, atol=1e-1, rtol=1e-1) + + print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") + torch.testing.assert_close(selected_logprobs, original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From 4068437d4791cc13d68a85c8550107d286bc67ef Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:19:54 +0200 Subject: [PATCH 330/342] torch testing assert close --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 3fef7bb86041..3ff800c78484 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -437,7 +437,7 @@ def test_model_matches_original_20b(self): logprob_differences = selected_logprobs - original_logprobs[i] print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(logprobs, logprob_differences, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(selected_logprobs, original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From 94f11c59e46894f1dca1df8cd9ece4659194c94a Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:21:54 +0200 Subject: [PATCH 331/342] torch testing assert close --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 3ff800c78484..83c8af18c41f 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -387,7 +387,7 @@ def test_model_matches_original_20b(self): input_text = "Roses are red, violets" original_output = "Roses are red, violets are blue, I love you, and I love you too." - original_logprobs = [ + original_logprobs = torch.tensor([ -0.037353515625, -0.08154296875, -1.21875, @@ -408,7 +408,7 @@ def test_model_matches_original_20b(self): -0.1767578125, -0.1240234375, -0.099609375, - ] + ]) model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" @@ -447,7 +447,7 @@ def test_model_matches_original_120b(self): original_output = """Roses are red, violets are blue, I am a language model, not a human being""" - original_logprobs = [ + original_logprobs = torch.tensor([ -0.90234375, -0.66015625, -1.546875, @@ -468,7 +468,7 @@ def test_model_matches_original_120b(self): -1.5390625, -1.46875, -0.85546875, - ] + ]) model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs" From 3660b2b3d03c25c4f82e00dc9d1658b04f5d7d9b Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:22:51 +0200 Subject: [PATCH 332/342] torch testing assert close --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 83c8af18c41f..143c616350ec 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -437,7 +437,7 @@ def test_model_matches_original_20b(self): logprob_differences = selected_logprobs - original_logprobs[i] print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs, original_logprobs[i], atol=1e-1, rtol=1e-1) + torch.testing.assert_close(selected_logprobs.cpu(), original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) @@ -497,7 +497,7 @@ def test_model_matches_original_120b(self): logprob_differences = selected_logprobs - original_logprobs[i] print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs, original_logprobs[i], atol=1e-1, rtol=1e-1) + torch.testing.assert_close(selected_logprobs.cpu(), original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From 974987fa29a1dff1e16aad87fdb6516df404443a Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 10:23:46 +0200 Subject: [PATCH 333/342] torch testing assert close --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 143c616350ec..9f3b5a894bc8 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -437,7 +437,7 @@ def test_model_matches_original_20b(self): logprob_differences = selected_logprobs - original_logprobs[i] print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs.cpu(), original_logprobs[i], atol=1e-1, rtol=1e-1) + torch.testing.assert_close(selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) @@ -497,7 +497,7 @@ def test_model_matches_original_120b(self): logprob_differences = selected_logprobs - original_logprobs[i] print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs.cpu(), original_logprobs[i], atol=1e-1, rtol=1e-1) + torch.testing.assert_close(selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From d881a2003dca6414de171e46c7c394399f274abe Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 5 Aug 2025 16:50:47 +0200 Subject: [PATCH 334/342] Revert fixup --- src/transformers/modeling_utils.py | 11 +++-------- src/transformers/quantizers/quantizer_mxfp4.py | 6 ++---- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d3d8aa2610e0..36aa8f3861aa 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5014,11 +5014,7 @@ def from_pretrained( if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, - device_map=device_map, - keep_in_fp32_modules=model._keep_in_fp32_modules, - config=config, - use_kernels=use_kernels, + model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config, use_kernels=use_kernels ) # We store the original dtype for quantized models as we cannot easily retrieve it # once the weights have been quantized @@ -5084,10 +5080,9 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: + if not is_kernels_available(): - raise ValueError( - "Kernels are not available. To use kernels, please install kernels using `pip install kernels`" - ) + raise ValueError("Kernels are not available. To use kernels, please install kernels using `pip install kernels`") from kernels import Device, kernelize diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 061ca072f029..4b00998d925c 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -267,10 +267,8 @@ def _process_model_before_weight_loading( use_kernels = kwargs.get("use_kernels", False) # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling if use_kernels: - logger.warning_once( - "You are using full precision kernels, we will dequantize the model to bf16. " - "To use the quantized model with quantization kernels, please set use_kernels=False" - ) + logger.warning_once("You are using full precision kernels, we will dequantize the model to bf16. " + "To use the quantized model with quantization kernels, please set use_kernels=False") self.quantization_config.dequantize = True config = model.config From 66980045f3063ca2aa38c18a6b437c569d6252c3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Aug 2025 15:04:57 +0000 Subject: [PATCH 335/342] skip 2 test remove todo --- src/transformers/models/gpt_oss/configuration_gpt_oss.py | 2 +- tests/models/gpt_oss/test_modeling_gpt_oss.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/configuration_gpt_oss.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py index 6d440b864294..0a120e7ec970 100644 --- a/src/transformers/models/gpt_oss/configuration_gpt_oss.py +++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py @@ -48,7 +48,7 @@ class GptOssConfig(PretrainedConfig): def __init__( self, num_hidden_layers: int = 36, - num_local_experts: int = 128, # TODO: rename to num_experts otherwise confusing with EP + num_local_experts: int = 128, vocab_size: int = 201088, hidden_size: int = 2880, intermediate_size: int = 2880, diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index a0d9fa9b2b67..4f3c5ae6d6c0 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -166,6 +166,14 @@ def test_eager_matches_fa2_generate(self): def test_flash_attn_2_equivalence(self): pass + @unittest.skip("Most probably because of the MOE, the moe and router does not ignore padding tokens") + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("GptOss does not support flex officially") + def test_flex_attention_with_grads(self): + pass + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" From 54cf55fa21c94a7c0d6887a0b8ec107ce23329be Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Aug 2025 15:08:06 +0000 Subject: [PATCH 336/342] merge --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 108 ++++++++++-------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index dd837a592934..c1d24e267c9b 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -395,28 +395,30 @@ def test_model_matches_original_20b(self): input_text = "Roses are red, violets" original_output = "Roses are red, violets are blue, I love you, and I love you too." - original_logprobs = torch.tensor([ - -0.037353515625, - -0.08154296875, - -1.21875, - -1.953125, - -2.234375, - -0.96875, - -1.546875, - -1.640625, - -0.93359375, - -1.609375, - -1.625, - -0.85546875, - -1.7265625, - -0.7421875, - -2.078125, - -0.006561279296875, - -0.10498046875, - -0.1767578125, - -0.1240234375, - -0.099609375, - ]) + original_logprobs = torch.tensor( + [ + -0.037353515625, + -0.08154296875, + -1.21875, + -1.953125, + -2.234375, + -0.96875, + -1.546875, + -1.640625, + -0.93359375, + -1.609375, + -1.625, + -0.85546875, + -1.7265625, + -0.7421875, + -2.078125, + -0.006561279296875, + -0.10498046875, + -0.1767578125, + -0.1240234375, + -0.099609375, + ] + ) model_id = "/fsx/vb/new-oai/gpt-oss-20b-trfs" @@ -444,8 +446,12 @@ def test_model_matches_original_20b(self): decoded_token = tokenizer.decode([predicted_token]) logprob_differences = selected_logprobs - original_logprobs[i] - print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1) + print( + f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}" + ) + torch.testing.assert_close( + selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1 + ) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) @@ -455,28 +461,30 @@ def test_model_matches_original_120b(self): original_output = """Roses are red, violets are blue, I am a language model, not a human being""" - original_logprobs = torch.tensor([ - -0.90234375, - -0.66015625, - -1.546875, - -2.703125, - -2.078125, - -1.21875, - -2.484375, - -0.031982421875, - -0.84765625, - -1.890625, - -0.1923828125, - -2.046875, - -1.65625, - -1.3515625, - -1.1640625, - -0.3671875, - -1.9921875, - -1.5390625, - -1.46875, - -0.85546875, - ]) + original_logprobs = torch.tensor( + [ + -0.90234375, + -0.66015625, + -1.546875, + -2.703125, + -2.078125, + -1.21875, + -2.484375, + -0.031982421875, + -0.84765625, + -1.890625, + -0.1923828125, + -2.046875, + -1.65625, + -1.3515625, + -1.1640625, + -0.3671875, + -1.9921875, + -1.5390625, + -1.46875, + -0.85546875, + ] + ) model_id = "/fsx/vb/new-oai/gpt-oss-120b-trfs" @@ -504,8 +512,12 @@ def test_model_matches_original_120b(self): decoded_token = tokenizer.decode([predicted_token]) logprob_differences = selected_logprobs - original_logprobs[i] - print(f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}") - torch.testing.assert_close(selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1) + print( + f"Generated token: {repr(decoded_token)}, logprob: {selected_logprobs}, logprob differences: {logprob_differences}" + ) + torch.testing.assert_close( + selected_logprobs.cpu().to(original_logprobs.dtype), original_logprobs[i], atol=1e-1, rtol=1e-1 + ) decoded_string = tokenizer.decode(tokens) self.assertTrue(original_output.startswith(decoded_string)) From f19e04b9c8a69728fdc12f86832ac110c6fd1be5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Aug 2025 15:34:29 +0000 Subject: [PATCH 337/342] padding side should be left for integration tests --- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index c1d24e267c9b..ab978b9c9bc5 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -206,7 +206,7 @@ def distributed_worker(quantized, model_size, kernels, attn_impl, mode): use_kernels=kernels, ).to(torch_device) model.set_attn_implementation(attn_impl) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") # Inference inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(torch_device) @@ -261,7 +261,7 @@ def load_and_forward(model_id, attn_implementation, input_text, **pretrained_kwa attn_implementation=attn_implementation, **pretrained_kwargs, ) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device) output = model.generate(**inputs, max_new_tokens=20, do_sample=False) From 1f7cad0683b4c211f2e9a8ce38a21ef3cc13e871 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Aug 2025 15:38:44 +0000 Subject: [PATCH 338/342] fix modular wrt to changes made to modeling --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 4 ++-- src/transformers/models/gpt_oss/modular_gpt_oss.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 355dd069eb83..014708e9e0d1 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -242,13 +242,12 @@ def eager_attention_forward( attn_weights = attn_weights + causal_mask sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 # when training with bsz>1 we clamp max values. - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) scores = probs[..., :-1] # we drop the sink here attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) @@ -392,6 +391,7 @@ class GptOssPreTrainedModel(PreTrainedModel): "attentions": GptOssAttention, } _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_flash_attention = False _supports_flex_attention = False def _init_weights(self, module): diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index b2c973ff49ba..07ba95c46f69 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -210,8 +210,11 @@ def eager_attention_forward( attn_weights = attn_weights + causal_mask sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) scores = probs[..., :-1] # we drop the sink here @@ -330,6 +333,7 @@ def forward( class GptOssPreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _supports_sdpa = False + _supports_flash_attention = False _supports_flex_attention = False _can_record_outputs = { "router_logits": OutputRecorder(GptOssTopKRouter, index=0), @@ -343,6 +347,8 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: @@ -356,6 +362,9 @@ def _init_weights(self, module): module.down_proj_bias.data.zero_() elif isinstance(module, GptOssAttention): module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) class GptOssModel(MixtralModel): From 6973ba409b3341a5c9a43688b42034fabce114d1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Aug 2025 15:39:59 +0000 Subject: [PATCH 339/342] style --- src/transformers/modeling_utils.py | 11 ++++++++--- src/transformers/models/gpt_oss/modular_gpt_oss.py | 3 ++- src/transformers/quantizers/quantizer_mxfp4.py | 6 ++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 02ae44bc5495..0eab1cbab9d8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5013,7 +5013,11 @@ def from_pretrained( if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config, use_kernels=use_kernels + model=model, + device_map=device_map, + keep_in_fp32_modules=model._keep_in_fp32_modules, + config=config, + use_kernels=use_kernels, ) # We store the original dtype for quantized models as we cannot easily retrieve it # once the weights have been quantized @@ -5079,9 +5083,10 @@ def _assign_original_dtype(module): # check if using kernels if use_kernels: - if not is_kernels_available(): - raise ValueError("Kernels are not available. To use kernels, please install kernels using `pip install kernels`") + raise ValueError( + "Kernels are not available. To use kernels, please install kernels using `pip install kernels`" + ) from kernels import Device, kernelize diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 07ba95c46f69..9b4eb578b73b 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Callable, Optional import torch from torch import nn @@ -25,6 +25,7 @@ MoeModelOutputWithPast, ) from ...modeling_rope_utils import dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 4b00998d925c..061ca072f029 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -267,8 +267,10 @@ def _process_model_before_weight_loading( use_kernels = kwargs.get("use_kernels", False) # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling if use_kernels: - logger.warning_once("You are using full precision kernels, we will dequantize the model to bf16. " - "To use the quantized model with quantization kernels, please set use_kernels=False") + logger.warning_once( + "You are using full precision kernels, we will dequantize the model to bf16. " + "To use the quantized model with quantization kernels, please set use_kernels=False" + ) self.quantization_config.dequantize = True config = model.config From 1f47841b4a9e8b3e3246f6e90fc35f6d6b11ef04 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Aug 2025 17:48:45 +0200 Subject: [PATCH 340/342] isort --- src/transformers/integrations/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 14fdae321e47..390db81867fd 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -120,12 +120,12 @@ "run_hp_search_wandb", ], "mxfp4": [ - "replace_with_mxfp4_linear", "Mxfp4GptOssExperts", - "quantize_to_mxfp4", "convert_moe_packed_tensors", "dequantize", "load_and_swizzle_mxfp4", + "quantize_to_mxfp4", + "replace_with_mxfp4_linear", ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], From 865b368bec7b7f2279850a71434a44e2c5ae986f Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Aug 2025 17:51:41 +0200 Subject: [PATCH 341/342] fix opies for the loss --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 9 +++------ .../models/granitemoe/modeling_granitemoe.py | 9 ++++++--- .../models/granitemoehybrid/modeling_granitemoehybrid.py | 9 ++++++--- .../models/granitemoeshared/modeling_granitemoeshared.py | 9 ++++++--- src/transformers/models/jamba/modeling_jamba.py | 9 ++++++--- src/transformers/models/jetmoe/modeling_jetmoe.py | 9 ++++++--- src/transformers/models/mixtral/modeling_mixtral.py | 9 +++------ src/transformers/models/olmoe/modeling_olmoe.py | 9 ++++++--- src/transformers/models/phimoe/modeling_phimoe.py | 9 ++++++--- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 9 ++++++--- 10 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 014708e9e0d1..2077f7372c9d 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -573,8 +573,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) - .reshape(-1, routing_weights.shape[1]) + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) .to(compute_device) ) @@ -583,10 +583,7 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - rank = routing_weights.shape[1] * int(routing_weights.device.index) - overall_loss = torch.sum( - tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) - ) + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index fffe51d794bd..96fc1ca3373c 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -109,8 +109,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -119,7 +119,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 408cb861a142..c727d40f448b 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1637,8 +1637,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -1647,7 +1647,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 21e9d13f7195..f2f5d7d6f0f1 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -908,8 +908,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -918,7 +918,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index e4a376e90af1..191e82e8e852 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -138,8 +138,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -148,7 +148,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 5156ac59742e..997885944142 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -119,8 +119,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -129,7 +129,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index d52aaea702c1..043862a3a2c0 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -543,8 +543,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) - .reshape(-1, routing_weights.shape[1]) + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) .to(compute_device) ) @@ -553,10 +553,7 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - rank = routing_weights.shape[1] * int(routing_weights.device.index) - overall_loss = torch.sum( - tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) - ) + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index eacb56c064b6..c9540e33af4c 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -108,8 +108,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -118,7 +118,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 2207793dcaea..a4735a04ac24 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -124,8 +124,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -134,7 +134,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 108aa9617614..a9cc23b37e22 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -128,8 +128,8 @@ def load_balancing_loss_func( # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) + .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1])) + .reshape(-1, routing_weights.shape[1]) .to(compute_device) ) @@ -138,7 +138,10 @@ def load_balancing_loss_func( router_per_expert_attention_mask, dim=0 ) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + rank = routing_weights.shape[1] * int(routing_weights.device.index) + overall_loss = torch.sum( + tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts From 75f13d05da82d438e088c4f3e5cebf9bb57825d1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Aug 2025 17:58:34 +0200 Subject: [PATCH 342/342] mmmm --- utils/check_docstrings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index a9ca4f1d56f2..23ba44958dcb 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -79,6 +79,7 @@ # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the # line before the docstring. OBJECTS_TO_IGNORE = [ + "Mxfp4Config", "Exaone4Config", "SmolLM3Config", "Gemma3nVisionConfig",