diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 3d1b0b169636..062b0eb50637 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -411,6 +411,8 @@
title: Blenderbot Small
- local: model_doc/bloom
title: BLOOM
+ - local: model_doc/blt
+ title: BLT
- local: model_doc/bort
title: BORT
- local: model_doc/byt5
diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md
new file mode 100644
index 000000000000..0289f77ac901
--- /dev/null
+++ b/docs/source/en/model_doc/blt.md
@@ -0,0 +1,97 @@
+
+
+
+
+# Byte Latent Transformer (BLT)
+
+## Overview
+
+The BLT model was proposed in [Byte Latent Transformer: Patches Scale Better Than Tokens](https://huggingface.co/papers/2412.09871) by Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman, and Srinivasan Iyer.
+BLT is a byte-level LLM that matches the performance of tokenization-based LLMs through entropy-based dynamic patching.
+
+The abstract from the paper is the following:
+
+*We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that, for the first time, matches tokenization-based LLM performance at scale with significant improvements in inference
+efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve as the primary units of computation. Patches are segmented based on the entropy of the next byte, allocating
+more compute and model capacity where increased data complexity demands it. We present the first flop controlled scaling study of byte-level models up to 8B parameters and 4T training bytes. Our results demonstrate the feasibility of scaling models trained on raw bytes without a fixed vocabulary. Both training and inference efficiency improve due to dynamically selecting long patches when data is predictable, along with qualitative improvements on reasoning and long tail generalization. Overall, for fixed inference costs, BLT shows significantly better scaling than tokenization-based models, by simultaneously growing both patch and model size.*
+
+## Usage tips
+
+- **Dual Model Architecture**: BLT consists of two separately trained models:
+ - **Patcher (Entropy Model)**: A smaller transformer model that predicts byte-level entropy to determine patch boundaries and segment input.
+ - **Main Transformer Model**: The primary model that processes the patches through a Local Encoder, Global Transformer, and Local Decoder.
+
+- **Dynamic Patching**: The model uses entropy-based dynamic patching where:
+  - High-entropy regions (complex data) get shorter patches, so more compute and model capacity are allocated to them
+ - Low-entropy regions (predictable data) get longer patches for efficiency
+ - This allows the model to allocate compute resources where they're most needed
+
+- **Local Encoder**: Processes byte sequences with cross-attention to patch embeddings
+- **Global Transformer**: Processes patch-level representations with full attention across patches
+- **Local Decoder**: Generates output with cross-attention back to the original byte sequence
+
+- **Byte-Level Tokenizer**: Unlike traditional tokenizers that use learned vocabularies, BLT's tokenizer simply converts text to UTF-8 bytes and maps each byte to a token ID. There is no need for a vocabulary.
+
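+The dynamic patching described above amounts to thresholding the patcher's next-byte entropy. The sketch below is illustrative only: it uses random logits as a stand-in for the patcher model, and the model's actual boundary logic, controlled by `patching_threshold`, may differ in detail:
+
+```python
+import torch
+
+def entropy_patch_starts(logits: torch.Tensor, threshold: float) -> torch.Tensor:
+    # logits: (seq_len, vocab_size) next-byte logits from the small patcher model
+    log_probs = torch.log_softmax(logits, dim=-1)
+    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # per-position next-byte entropy (nats)
+    return entropy > threshold  # True where a new patch would start
+
+logits = torch.randn(16, 260)  # random stand-in for patcher outputs (byte vocabulary of 260)
+print(entropy_patch_starts(logits, threshold=1.335))  # 1.335 is the default patching_threshold
+```
+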
+The model can be loaded via:
+
+
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+model = AutoModelForCausalLM.from_pretrained(
+ "itazap/blt-1b-hf",
+ device_map="auto",
+)
+
+prompt = "my name is"
+inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+generated_ids = model.generate(
+    **inputs, max_new_tokens=100, do_sample=False, use_cache=False
+)
+
+print(tokenizer.decode(generated_ids[0]))
+```
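+
+Since the tokenizer operates on raw UTF-8 bytes, the encoded ids track the byte values of the input. A quick check with the same checkpoint as above (the ids may differ from the raw byte values by a small offset reserved for special tokens):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+text = "hello"
+print(list(text.encode("utf-8")))  # raw UTF-8 byte values
+print(tokenizer(text, add_special_tokens=False)["input_ids"])  # one id per byte (possibly offset)
+```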
+
+
+
+This model was contributed by [itazap](https://huggingface.co/itazap).
+The original code can be found [here](https://github.com/facebookresearch/blt).
+
+
+## BltConfig
+
+[[autodoc]] BltConfig
+
+## BltModel
+
+[[autodoc]] BltModel
+ - forward
+
+## BltForCausalLM
+
+[[autodoc]] BltForCausalLM
+ - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index c32c8a795488..f0939b089977 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -48,6 +48,7 @@
from .blip import *
from .blip_2 import *
from .bloom import *
+ from .blt import *
from .bridgetower import *
from .bros import *
from .byt5 import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 06023f09c9d8..ec6ce58f7994 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -65,6 +65,7 @@
("blip-2", "Blip2Config"),
("blip_2_qformer", "Blip2QFormerConfig"),
("bloom", "BloomConfig"),
+ ("blt", "BltConfig"),
("bridgetower", "BridgeTowerConfig"),
("bros", "BrosConfig"),
("camembert", "CamembertConfig"),
@@ -490,6 +491,7 @@
("blip-2", "BLIP-2"),
("blip_2_qformer", "BLIP-2 QFormer"),
("bloom", "BLOOM"),
+ ("blt", "Blt"),
("bort", "BORT"),
("bridgetower", "BridgeTower"),
("bros", "BROS"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 025a7a1f90a0..3d0ee2e9fcbd 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -72,6 +72,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("blip-2", "Blip2Model"),
("blip_2_qformer", "Blip2QFormerModel"),
("bloom", "BloomModel"),
+ ("blt", "BltModel"),
("bridgetower", "BridgeTowerModel"),
("bros", "BrosModel"),
("camembert", "CamembertModel"),
@@ -633,6 +634,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("bloom", "BloomForCausalLM"),
+ ("blt", "BltForCausalLM"),
("camembert", "CamembertForCausalLM"),
("code_llama", "LlamaForCausalLM"),
("codegen", "CodeGenForCausalLM"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 7858ae587946..52726fd6200a 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -105,6 +105,7 @@
("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
+ ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("byt5", ("ByT5Tokenizer", None)),
diff --git a/src/transformers/models/blt/__init__.py b/src/transformers/models/blt/__init__.py
new file mode 100644
index 000000000000..703b81ecdd09
--- /dev/null
+++ b/src/transformers/models/blt/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_blt import *
+ from .modeling_blt import *
+ from .tokenization_blt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py
new file mode 100644
index 000000000000..0bc6718e5bd1
--- /dev/null
+++ b/src/transformers/models/blt/configuration_blt.py
@@ -0,0 +1,423 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Blt model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BltLocalEncoderConfig(PretrainedConfig):
+ """
+ Configuration class for the Blt Local Encoder component.
+ """
+
+ model_type = "blt_local_encoder"
+
+ def __init__(
+ self,
+ vocab_size=260,
+ cross_attn_all_layers=False,
+ cross_attn_k=2,
+ hidden_size_global=2048,
+ hidden_size=1024,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_hidden_layers=1,
+ rms_norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=24576,
+ rope_theta=500000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=2816,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.cross_attn_all_layers = cross_attn_all_layers
+ self.cross_attn_k = cross_attn_k
+ self.hidden_size_global = hidden_size_global
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.rms_norm_eps = rms_norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+ # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
+ kwargs.pop("tie_word_embeddings", None)
+ super().__init__(**kwargs, tie_word_embeddings=False)
+
+
+class BltLocalDecoderConfig(PretrainedConfig):
+ """
+ Configuration class for the Blt Local Decoder component.
+ """
+
+ model_type = "blt_local_decoder"
+
+ def __init__(
+ self,
+ vocab_size=260,
+ cross_attn_all_layers=True,
+ cross_attn_k=2,
+ hidden_size_global=2048,
+ hidden_size=1024,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_hidden_layers=9,
+ rms_norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=24576,
+ rope_theta=500000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=2816,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.cross_attn_all_layers = cross_attn_all_layers
+ self.cross_attn_k = cross_attn_k
+ self.hidden_size_global = hidden_size_global
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.rms_norm_eps = rms_norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+ # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
+ kwargs.pop("tie_word_embeddings", None)
+ super().__init__(**kwargs, tie_word_embeddings=False)
+
+
+class BltGlobalTransformerConfig(PretrainedConfig):
+ """
+ Configuration class for the Blt Global Transformer component.
+ """
+
+ model_type = "blt_global_transformer"
+
+ def __init__(
+ self,
+ hidden_size=2048,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_hidden_layers=25,
+ rms_norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=4096,
+ rope_theta=500000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=5632,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.rms_norm_eps = rms_norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+ # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
+ kwargs.pop("tie_word_embeddings", None)
+ super().__init__(**kwargs, tie_word_embeddings=False)
+
+
+class BltPatcherConfig(PretrainedConfig):
+ r"""
+ Configuration class for the Blt Patcher/Entropy model component.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 260):
+ Vocabulary size of the Blt patcher model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling the patcher model.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 14):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the MLP representations.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the RoPE scaling configuration.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
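+
+    Example (a minimal sketch, assuming `BltPatcherConfig` is exported at the top level, as listed in `__all__`):
+
+    ```python
+    >>> from transformers import BltPatcherConfig
+
+    >>> # Initializing a default BLT patcher (entropy model) configuration
+    >>> configuration = BltPatcherConfig()
+    ```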
+ """
+
+ model_type = "blt_patcher"
+
+ def __init__(
+ self,
+ vocab_size=260,
+ hidden_size=768,
+ num_hidden_layers=14,
+ num_attention_heads=12,
+ num_key_value_heads=None,
+ max_position_embeddings=8192,
+ rms_norm_eps=1e-5,
+ dropout=0.0,
+ rope_theta=10000.0,
+ intermediate_size=2048,
+ rope_scaling=None,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.rms_norm_eps = rms_norm_eps
+ self.dropout = dropout
+ self.rope_theta = rope_theta
+ self.hidden_act = "silu" # Blt uses silu activation
+ self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3)
+ self.rope_scaling = rope_scaling
+ self.initializer_range = initializer_range
+
+ # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
+ kwargs.pop("tie_word_embeddings", None)
+ super().__init__(**kwargs, tie_word_embeddings=False)
+
+
+class BltConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BltModel`]. It is used to instantiate a
+ Blt model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 260):
+ Vocabulary size of the Blt model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`BltModel`].
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ patch_in_forward (`bool`, *optional*, defaults to `True`):
+ Whether to perform patching during the forward pass.
+ patch_size (`int`, *optional*, defaults to 4):
+ Size of the patches used in the patching mechanism.
+ patching_mode (`str`, *optional*, defaults to `"entropy"`):
+ The mode used for patching, such as entropy-based patching.
+ patching_threshold (`float`, *optional*, defaults to 1.34):
+ Threshold value used for determining when to apply patches.
+ patching_batch_size (`int`, *optional*, defaults to 1):
+ Batch size used during the patching process.
+ max_patch_length (`int`, *optional*):
+ Maximum length of patches that can be generated.
+ cross_attn_k (`int`, *optional*, defaults to 2):
+            Multiplier applied when projecting patch embeddings for cross-attention (the projected size is `cross_attn_k * hidden_size`).
+ encoder_hash_byte_group_size (`list`, *optional*):
+ List of byte group sizes used in the encoder hash function.
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 500002):
+ Vocabulary size for the encoder hash byte groups.
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 1):
+ Number of hash functions used in the encoder byte grouping.
+ patcher_config (`BltPatcherConfig`, *optional*):
+ Configuration for the patcher component of the model.
+ encoder_config (`BltLocalEncoderConfig`, *optional*):
+ Configuration for the local encoder component of the model.
+ decoder_config (`BltLocalDecoderConfig`, *optional*):
+ Configuration for the local decoder component of the model.
+ global_config (`BltGlobalTransformerConfig`, *optional*):
+ Configuration for the global transformer component of the model.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rope_theta (`float`, *optional*, defaults to 500000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the RoPE scaling configuration.
+
+ ```python
+ >>> from transformers import BltModel, BltConfig
+
+ >>> # Initializing a Blt configuration
+ >>> configuration = BltConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = BltModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+
+ Checkpoint: [facebook/blt](https://huggingface.co/facebook/blt)
+ """
+
+ model_type = "blt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {
+ "patcher_config": BltPatcherConfig,
+ "encoder_config": BltLocalEncoderConfig,
+ "decoder_config": BltLocalDecoderConfig,
+ "global_config": BltGlobalTransformerConfig,
+ }
+
+ def __init__(
+ self,
+ vocab_size=260,
+ max_position_embeddings=4096,
+ patch_in_forward=True,
+ patch_size=4,
+ patching_mode="entropy",
+ patching_threshold=1.335442066192627,
+ patching_batch_size=1,
+ max_patch_length=None,
+ cross_attn_k=2,
+ encoder_hash_byte_group_size=None,
+ encoder_hash_byte_group_vocab=500002,
+ encoder_hash_byte_group_nb_functions=1,
+ patcher_config=None,
+ encoder_config=None,
+ decoder_config=None,
+ global_config=None,
+ tie_word_embeddings=False,
+ initializer_range=0.02,
+ rope_theta=500000.0,
+ rope_scaling=None,
+ **kwargs,
+ ):
+ # Basic model configuration
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+
+ # Patching configuration
+ self.patch_in_forward = patch_in_forward
+ self.patch_size = patch_size
+ self.patching_mode = patching_mode
+ self.patching_threshold = patching_threshold
+ self.patching_batch_size = patching_batch_size
+ self.max_patch_length = max_patch_length
+ self.patching_device = kwargs.get("patching_device", "cuda")
+ self.realtime_patching = kwargs.get("realtime_patching", True)
+ self.patching_threshold_add = kwargs.get("patching_threshold_add")
+ self.monotonicity = kwargs.get("monotonicity", False)
+
+ # Cross attention configurations
+ self.cross_attn_k = cross_attn_k
+
+ # Encoder configurations
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [3, 4, 5, 6, 7, 8]
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
+
+ # Initialize component configurations
+ if patcher_config is None:
+ self.patcher_config = BltPatcherConfig(initializer_range=initializer_range)
+ logger.info("patcher_config is None, using default Blt patcher config")
+ elif isinstance(patcher_config, dict):
+ patcher_config.setdefault("initializer_range", initializer_range)
+ self.patcher_config = BltPatcherConfig(**patcher_config)
+ elif isinstance(patcher_config, BltPatcherConfig):
+ self.patcher_config = patcher_config
+
+ if encoder_config is None:
+ self.encoder_config = BltLocalEncoderConfig(initializer_range=initializer_range)
+ logger.info("encoder_config is None, using default Blt encoder config")
+ elif isinstance(encoder_config, dict):
+ encoder_config.setdefault("initializer_range", initializer_range)
+ self.encoder_config = BltLocalEncoderConfig(**encoder_config)
+ elif isinstance(encoder_config, BltLocalEncoderConfig):
+ self.encoder_config = encoder_config
+
+ if decoder_config is None:
+ self.decoder_config = BltLocalDecoderConfig(initializer_range=initializer_range)
+ logger.info("decoder_config is None, using default Blt decoder config")
+ elif isinstance(decoder_config, dict):
+ decoder_config.setdefault("initializer_range", initializer_range)
+ self.decoder_config = BltLocalDecoderConfig(**decoder_config)
+ elif isinstance(decoder_config, BltLocalDecoderConfig):
+ self.decoder_config = decoder_config
+
+ if global_config is None:
+ self.global_config = BltGlobalTransformerConfig(initializer_range=initializer_range)
+ logger.info("global_config is None, using default Blt global config")
+ elif isinstance(global_config, dict):
+ global_config.setdefault("initializer_range", initializer_range)
+ self.global_config = BltGlobalTransformerConfig(**global_config)
+ elif isinstance(global_config, BltGlobalTransformerConfig):
+ self.global_config = global_config
+
+ # Determine if token embedding projection is needed based on dimension mismatch (7b)
+ encoder_cross_output_size = self.encoder_config.hidden_size * self.cross_attn_k
+ self.global_config.encoder_cross_output_size = (
+ encoder_cross_output_size if encoder_cross_output_size != self.global_config.hidden_size else None
+ )
+
+ # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
+ kwargs.pop("tie_word_embeddings", None)
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = [
+ "BltConfig",
+ "BltPatcherConfig",
+ "BltLocalEncoderConfig",
+ "BltLocalDecoderConfig",
+ "BltGlobalTransformerConfig",
+]
diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py
new file mode 100644
index 000000000000..f9decff3a1f8
--- /dev/null
+++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py
@@ -0,0 +1,487 @@
+import argparse
+import json
+import logging
+import os
+from typing import Any, Optional
+
+import torch
+from huggingface_hub import hf_hub_download, upload_folder
+from safetensors.torch import load_file, save_file
+from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers.models import BPE
+
+from transformers import PreTrainedTokenizerFast
+from transformers.convert_slow_tokenizer import bytes_to_unicode
+from transformers.utils import logging as transformers_logging
+
+
+logger = transformers_logging.get_logger(__name__)
+transformers_logging.set_verbosity_info()
+
+
+def merge_configurations(config_path: str, entropy_params_path: str) -> dict[str, Any]:
+ logger.info("Merging configurations")
+
+ with open(config_path, "r") as f:
+ main_config = json.load(f)
+
+ with open(entropy_params_path, "r") as f:
+ entropy_data = json.load(f)
+
+ entropy_model_params = entropy_data.get("entropy_model", {})
+ patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
+
+ unified_config = main_config.copy()["args"]
+
+ for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
+ if key in unified_config and not isinstance(unified_config[key], int):
+ unified_config[key] = int(unified_config[key])
+
+ patch_size = patcher_args.get("patch_size", 8)
+ if isinstance(patch_size, float):
+ patch_size = int(patch_size)
+
+ # Create patcher config
+ patcher_hidden_size = int(entropy_model_params.get("dim", 512))
+ patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
+ patcher_intermediate_size = patcher_multiple_of * (
+ (int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of
+ )
+
+ patcher_config = {
+ "vocab_size": int(entropy_model_params.get("vocab_size", 256)),
+ "hidden_size": patcher_hidden_size,
+ "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
+ "num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
+ "num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
+ if entropy_model_params.get("n_kv_heads") is not None
+ else None,
+ "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
+ "norm_eps": entropy_model_params.get("norm_eps", 1e-5),
+ "dropout": entropy_model_params.get("dropout", 0.0),
+ "rope_theta": entropy_model_params.get("rope_theta", 10000.0),
+ "attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
+ "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
+ "intermediate_size": patcher_intermediate_size,
+ }
+
+ # Create encoder config
+ encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
+ encoder_multiple_of = unified_config.get("multiple_of", 256)
+ encoder_intermediate_size = encoder_multiple_of * (
+ (int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of
+ )
+
+ encoder_config = {
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "hidden_size_global": unified_config.get("dim_global", 2048),
+ "pm_size": unified_config.get("pm_size", 0),
+ "hidden_size": encoder_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
+ "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length")
+ or unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": encoder_intermediate_size,
+ }
+
+ # Create decoder config
+ decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
+ decoder_multiple_of = unified_config.get("multiple_of", 256)
+ decoder_intermediate_size = decoder_multiple_of * (
+ (int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of
+ )
+
+ decoder_config = {
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "hidden_size_global": unified_config.get("dim_global", 2048),
+ "hidden_size": decoder_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
+ "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length")
+ or unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": decoder_intermediate_size,
+ }
+
+ # Create global transformer config
+ global_hidden_size = unified_config.get("dim_global", 2048)
+ global_multiple_of = unified_config.get("multiple_of", 256)
+ global_intermediate_size = global_multiple_of * (
+ (int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of
+ )
+
+ global_config = {
+ "hidden_size": global_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_global", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads_global"),
+ "num_hidden_layers": unified_config.get("n_layers_global", 25),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": global_intermediate_size,
+ }
+
+ # Create main config with sub-configs
+ main_config_dict = {
+ "model_type": "blt",
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
+ "patch_in_forward": True,
+ "realtime_patching": True,
+ "patching_mode": "entropy",
+ "patch_size": patch_size,
+ "patching_threshold": patcher_args.get("threshold", 0.5),
+ "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
+ "max_patch_length": patcher_args.get("max_patch_length"),
+ "patching_batch_size": patcher_args.get("patching_batch_size", 1),
+ "patching_device": patcher_args.get("patching_device", "cuda"),
+ "monotonicity": patcher_args.get("monotonicity", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
+ "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
+ "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
+ "pm_size": unified_config.get("pm_size", 0),
+ "patcher_config": patcher_config,
+ "encoder_config": encoder_config,
+ "decoder_config": decoder_config,
+ "global_config": global_config,
+ }
+
+ main_config_dict["tie_word_embeddings"] = False
+
+ logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
+ return main_config_dict
+
+
+def apply_weight_mapping(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ component_mappings = {
+ ".attention.": ".self_attn.",
+ ".feed_forward.": ".mlp.",
+ ".attention_norm.": ".input_layernorm.",
+ ".ffn_norm.": ".post_attention_layernorm.",
+ ".tok_embeddings.": ".embed_tokens.",
+ ".cross_attn_norm_q.": ".q_norm.",
+ ".cross_attn_norm_kv.": ".k_norm.",
+ ".w1.": ".gate_proj.",
+ ".w2.": ".down_proj.",
+ ".w3.": ".up_proj.",
+ ".wq.": ".q_proj.",
+ ".wk.": ".k_proj.",
+ ".wv.": ".v_proj.",
+ ".wo.": ".o_proj.",
+ ".output.": ".lm_head.",
+ }
+
+ new_state_dict = {}
+
+ for old_key, tensor in state_dict.items():
+ new_key = old_key
+
+ for old_pattern, new_pattern in component_mappings.items():
+ if old_pattern in new_key:
+ new_key = new_key.replace(old_pattern, new_pattern)
+
+ new_state_dict[new_key] = tensor
+
+ return new_state_dict
+
+
+def convert_hash_embeddings_to_fused(
+ unified_weights: dict[str, torch.Tensor], config: dict[str, Any]
+) -> dict[str, torch.Tensor]:
+ """Convert ModuleList hash embeddings to nn.embedding format"""
+ original_keys_format = [
+ key
+ for key in unified_weights.keys()
+ if "encoder_hash_tok_embedding." in key and ".weight" in key and key.split(".")[-2].isdigit()
+ ]
+
+ num_embeddings = config.get("encoder_hash_byte_group_nb_functions", 1) * len(
+ config.get("encoder_hash_byte_group_size", [3, 4, 5, 6, 7, 8])
+ )
+ vocab_size = config.get("encoder_hash_byte_group_vocab", 500002)
+ hidden_size = config.get("encoder_config", {}).get("hidden_size", 1024)
+
+ fused_weight = torch.zeros(vocab_size * num_embeddings, hidden_size)
+
+ sorted_keys = sorted(original_keys_format, key=lambda k: int(k.split(".")[-2]))
+
+ for i, old_key in enumerate(sorted_keys):
+ start_idx = i * vocab_size
+ end_idx = (i + 1) * vocab_size
+ fused_weight[start_idx:end_idx] = unified_weights[old_key]
+ logger.info(f"Copied {old_key} to indices {start_idx}:{end_idx}")
+ del unified_weights[old_key]
+
+ fused_key = "model.encoder_hash_tok_embedding.weight"
+ unified_weights[fused_key] = fused_weight
+
+ return unified_weights
+
+
+def merge_weights(weights_path: str, entropy_weights_path: str) -> dict[str, torch.Tensor]:
+ main_weights = load_file(weights_path)
+
+ entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
+
+ if "model" in entropy_weights:
+ entropy_weights = entropy_weights["model"]
+ elif "state_dict" in entropy_weights:
+ entropy_weights = entropy_weights["state_dict"]
+
+ unified_weights = main_weights.copy()
+
+ for key, tensor in entropy_weights.items():
+ patcher_key = f"patcher.{key}"
+ unified_weights[patcher_key] = tensor
+
+ unified_weights = apply_weight_mapping(unified_weights)
+
+ decoder_lm_head_key = "local_decoder.lm_head.weight"
+ top_lm_head_key = "lm_head.weight"
+ unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
+ del unified_weights[decoder_lm_head_key]
+
+ prefixed_weights = {}
+ for key, tensor in unified_weights.items():
+ if key == top_lm_head_key:
+ prefixed_weights[key] = tensor
+ elif not key.startswith("model."):
+ prefixed_weights[f"model.{key}"] = tensor
+ else:
+ prefixed_weights[key] = tensor
+
+ unified_weights = prefixed_weights
+
+ return unified_weights
+
+
+def create_tokenizer_config(output_dir: str, config: dict[str, Any]):
+ tokenizer_config = {
+ "tokenizer_class": "PreTrainedTokenizerFast",
+ "vocab_size": config.get("vocab_size", 256),
+ "model_max_length": config.get("max_seqlen", 1024),
+ "model_input_names": ["input_ids", "attention_mask"],
+ "add_bos_token": True,
+ "add_eos_token": True,
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": "",
+ }
+
+ tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
+ with open(tokenizer_path, "w") as f:
+ json.dump(tokenizer_config, f, indent=2)
+
+
+def create_tokenizer_json(output_dir: str, config: dict[str, Any]):
+ byte_encoder = bytes_to_unicode()
+
+ vocab: dict[str, int] = {}
+ vocab[""] = 0
+ vocab[""] = 1
+ vocab[""] = 2
+ vocab[""] = 3
+
+ offset = 4
+ for byte_val, unicode_char in byte_encoder.items():
+ vocab[unicode_char] = byte_val + offset
+
+ backend = Tokenizer(
+ BPE(vocab=vocab, merges=[], continuing_subword_prefix="", end_of_word_suffix="", fuse_unk=False)
+ )
+ backend.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+ backend.decoder = decoders.ByteLevel()
+
+    bos = config.get("bos_token", "<bos>")
+ backend.post_processor = processors.TemplateProcessing(
+ single=f"{bos}:0 $A:0",
+ pair=f"{bos}:0 $A:0 $B:1",
+ special_tokens=[(bos, 1)],
+ )
+
+ tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=backend,
+        bos_token=config.get("bos_token", "<bos>"),
+        eos_token=config.get("eos_token", "<eos>"),
+        pad_token=config.get("pad_token", "<pad>"),
+        unk_token=config.get("unk_token", "<unk>"),
+ )
+
+ tokenizer.add_bos_token = bool(config.get("add_bos_token", True))
+ tokenizer.add_eos_token = bool(config.get("add_eos_token", True))
+
+ tokenizer.save_pretrained(output_dir)
+ logger.info(f"Saved tokenizer.json to {os.path.join(output_dir, 'tokenizer.json')}")
+
+
+def push_to_hub(
+ local_dir: str,
+ repo_id: str,
+ commit_message: str = "Upload converted Blt model",
+ private: bool = False,
+ token: Optional[str] = None,
+) -> None:
+ try:
+ upload_folder(
+ folder_path=local_dir,
+ repo_id=repo_id,
+ commit_message=commit_message,
+ repo_type="model",
+ token=token,
+ )
+ logger.info(f"Successfully pushed model to {repo_id}")
+
+ except Exception as e:
+ logger.error(f"Failed to push model to Hub: {e}")
+ raise
+
+
+def convert_hf_blt_to_unified(
+ model_id: str,
+ output_dir: str,
+ config_name: str = "config.json",
+ weights_name: str = "model.bin",
+ cache_dir: Optional[str] = None,
+ push_to_hub_repo: Optional[str] = None,
+ hub_private: bool = False,
+ hub_token: Optional[str] = None,
+) -> None:
+ # Download model files
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
+ weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
+ entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir)
+ entropy_weights_path = hf_hub_download(
+ repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
+ )
+
+ unified_config = merge_configurations(config_path, entropy_params_path)
+ unified_weights = merge_weights(weights_path, entropy_weights_path)
+
+ unified_weights = convert_hash_embeddings_to_fused(unified_weights, unified_config)
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ config_path = os.path.join(output_dir, config_name)
+ with open(config_path, "w") as f:
+ json.dump(unified_config, f, indent=2)
+
+ if weights_name.endswith(".bin"):
+ weights_name = weights_name.replace(".bin", ".safetensors")
+
+ weights_path = os.path.join(output_dir, weights_name)
+ save_file(unified_weights, weights_path)
+
+ create_tokenizer_json(output_dir=output_dir, config=unified_config)
+
+ create_tokenizer_config(output_dir, unified_config)
+
+ logger.info(f"Conversion completed, model saved to: {output_dir}")
+
+ if push_to_hub_repo:
+ push_to_hub(
+ local_dir=output_dir,
+ repo_id=push_to_hub_repo,
+ commit_message="Upload Blt model converted",
+ private=hub_private,
+ token=hub_token,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert Blt models from HuggingFace Hub format to unified format",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--model_id",
+ type=str,
+ default="facebook/blt-7b",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./blt_converted",
+ )
+ parser.add_argument(
+ "--config_name",
+ type=str,
+ default="config.json",
+ )
+ parser.add_argument(
+ "--weights_name",
+ type=str,
+ default="model.bin",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--hub_private",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--hub_token",
+ type=str,
+ default="hf_token",
+ )
+
+ args = parser.parse_args()
+
+    if args.debug:
+        transformers_logging.set_verbosity_debug()
+        logging.basicConfig(level=logging.DEBUG)
+
+ try:
+ convert_hf_blt_to_unified(
+ model_id=args.model_id,
+ output_dir=args.output_dir,
+ config_name=args.config_name,
+ weights_name=args.weights_name,
+ cache_dir=args.cache_dir,
+            push_to_hub_repo=args.push_to_hub,
+ hub_private=args.hub_private,
+ hub_token=args.hub_token,
+ )
+ except Exception as e:
+ logger.error(f"Conversion failed: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py
new file mode 100644
index 000000000000..e1639d4e3e2b
--- /dev/null
+++ b/src/transformers/models/blt/modeling_blt.py
@@ -0,0 +1,1306 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/blt/modular_blt.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_blt.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.distributions
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_blt import (
+ BltConfig,
+ BltGlobalTransformerConfig,
+ BltLocalDecoderConfig,
+ BltLocalEncoderConfig,
+ BltPatcherConfig,
+)
+
+
+class BltMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ # Ignore copy
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class BltRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ BltRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class BltRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: BltConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
+class BltTransformerLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx)
+ self.mlp = BltMLP(config)
+ self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_half(x):
+ # Split and rotate. Note that this function is different from e.g. Llama.
+ x1 = x[..., ::2]
+ x2 = x[..., 1::2]
+ rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+ return rot_x
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class BltSelfAttention(nn.Module):
+ def __init__(self, config: BltConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.num_heads = config.num_attention_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.num_key_value_heads = config.num_key_value_heads
+ self.head_dim = config.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = config.rope_theta
+ self.layer_idx = layer_idx
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.is_causal = True
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ use_cache: bool = False,
+ past_key_values=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class BltCrossAttention(nn.Module):
+ """Cross-attention module for Blt, following transformers style"""
+
+ def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.head_dim = config.hidden_size // self.num_heads
+ self.layer_idx = layer_idx
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.is_causal = False
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_norm(hidden_states)
+ query_states = self.q_proj(query_states)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if cross_attention_states is not None:
+ cross_attention_states = self.k_norm(cross_attention_states)
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ elif cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_values.layers[self.layer_idx].keys,
+ past_key_values.layers[self.layer_idx].values,
+ )
+ else:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+ )
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ attn_output = attn_output + hidden_states
+ return attn_output, attn_weights
+
+
+@auto_docstring
+class BltPreTrainedModel(PreTrainedModel):
+ config: BltConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BltTransformerLayer"]
+ _can_compile_fullgraph = False # static cache cannot have different shapes for each layer
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = False
+ _can_record_outputs = {
+ "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"),
+ "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
+ }
+
+
+class BltLocalEncoder(BltPreTrainedModel):
+ config: BltLocalEncoderConfig
+ _can_record_outputs = {
+ "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
+ }
+
+ def __init__(self, config: BltLocalEncoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+ self.config = config
+ self.layers = nn.ModuleList(
+ [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.cross_attn_layers = nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size = inputs_embeds.shape[0]
+ hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
+
+ if position_ids is None:
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ for idx, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids)
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(
+ batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
+ )
+ layer_idx = idx if self.config.cross_attn_all_layers else 0
+ cross_attention_output, _ = self.cross_attn_layers[layer_idx](
+ hidden_states=patch_embeds,
+ cross_attention_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ patch_embeds = patch_embeds + cross_attention_output
+ encoder_cross_states = patch_embeds
+ return hidden_states, encoder_cross_states
+
+ def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
+ """
+ Reduce variable-length patches to a single embedding per patch.
+ Note: this works with a variable number of patches per sequence in the batch.
+ Variable-length patches are handled by assuming that patch_lengths will be 0 for any
+ extra patches on the *right*. Any embeddings on the right that are not allocated to a patch
+ (i.e. if sum(patch_lengths[i]) < seq_len for some i) are sent to a dummy patch, which is
+ trimmed before returning.
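+
+ Illustrative example (hypothetical values): with hidden_states of shape [1, 4, 2] equal to
+ [[[1, 2], [3, 4], [5, 6], [7, 8]]], patch_ids = [[0, 0, 1, 1]] and max_num_patches=2, the result
+ is [[[3, 4], [7, 8]]], i.e. an elementwise max over the tokens assigned to each patch.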
+ """
+ batch_size = hidden_states.shape[0]
+ embedding_dim = hidden_states.shape[-1]
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
+
+ reduced_embeddings = torch.zeros(
+ (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
+ src=hidden_states,
+ dim=1,
+ index=patch_ids,
+ reduce="amax",
+ include_self=False,
+ )
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
+
+ return reduced_embeddings
+
+
+class BltLocalDecoder(BltPreTrainedModel):
+ config: BltLocalDecoderConfig
+
+ def __init__(self, config: BltLocalDecoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+ self.config = config
+ self.cross_attn_decoder = True
+ self.layers = nn.ModuleList(
+ [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size_global,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+ self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.cross_attn_layers = nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ batch_size = inputs_embeds.shape[0]
+ hidden_states = inputs_embeds
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(
+ batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
+ )
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ hidden_states = hidden_states + patch_embeds
+
+ if position_ids is None:
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ for i, layer in enumerate(self.layers):
+ if i == 0 or self.config.cross_attn_all_layers:
+ cross_attention_output, _ = self.cross_attn_layers[i](
+ hidden_states=hidden_states,
+ cross_attention_states=patch_embeds,
+ attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ hidden_states = hidden_states + cross_attention_output
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ logits = self.norm(hidden_states)
+ return logits
+
+
+class BltGlobalTransformer(BltPreTrainedModel):
+ config: BltGlobalTransformerConfig
+ _can_record_outputs = {
+ "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
+ }
+
+ def __init__(self, config: BltGlobalTransformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.layers = nn.ModuleList()
+ for layer_idx in range(config.num_hidden_layers):
+ self.layers.append(BltTransformerLayer(config, layer_idx))
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+
+ # Create token embedding projection (use nn.Identity() when no projection needed)
+ if getattr(config, "encoder_cross_output_size", None) is not None:
+ self.token_embedding_projection = nn.Linear(
+ config.encoder_cross_output_size, config.hidden_size, bias=False
+ )
+ else:
+ self.token_embedding_projection = nn.Identity()
+
+ self.post_init()
+
+ def forward(
+ self,
+ input_embeds: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ batch_size, seq_len, _ = input_embeds.shape
+ hidden_states = self.token_embedding_projection(input_embeds)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+ if position_ids is None:
+ position_ids = (
+ torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ return hidden_states
+
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+ """
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+ Pads the result to uniform length across the batch.
+
+ Args:
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+ max_patch_length (int, optional): Maximum allowed length per patch.
+
+ Returns:
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
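+
+ Illustrative example (hypothetical values): with patch_lengths = [[5, 2, 0]] and
+ max_patch_length=3, the first patch is split into 3 + 2, giving [[3, 2, 2]].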
+ """
+ if max_patch_length is None:
+ return patch_lengths
+
+ batch_size = patch_lengths.size(0)
+ processed = []
+
+ for seq in patch_lengths:
+ splits = []
+ for length in seq[seq > 0]:
+ length = length.item()
+ full_chunks, remainder = divmod(length, max_patch_length)
+ splits.extend([max_patch_length] * full_chunks)
+ if remainder:
+ splits.append(remainder)
+ processed.append(splits)
+
+ # Find max length to pad to
+ max_len = max(len(splits) for splits in processed)
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ for i, splits in enumerate(processed):
+ if splits:
+ padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ # Trim zero columns
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+ padded = padded[:, :last_nonzero]
+
+ return padded
+
+
+class BltPatcher(BltPreTrainedModel):
+ config: BltPatcherConfig
+
+ def __init__(self, config: BltPatcherConfig):
+ super().__init__(config)
+ self.rotary_emb = BltRotaryEmbedding(config=self.config)
+ self.layers = nn.ModuleList()
+ for layer_idx in range(self.config.num_hidden_layers):
+ self.layers.append(BltTransformerLayer(self.config, layer_idx))
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+ self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.lm_head = nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ patch_size: Optional[int] = None,
+ threshold: Optional[float] = None,
+ max_patch_length: Optional[int] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)
+
+ logits = self.lm_head(self.norm(hidden_states))
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
+
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+ if patch_size is not None:
+ patch_lengths = self.patch_lengths_from_entropies(
+ entropies=prediction_entropies,
+ sequence_length=sequence_length,
+ patch_size=patch_size,
+ threshold=threshold,
+ )
+ else:
+ patch_lengths = torch.ones(
+ (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
+ return prediction_entropies, patch_lengths, logits
+
+ @staticmethod
+ def patch_lengths_from_entropies(
+ entropies,
+ sequence_length,
+ patch_size=None,
+ threshold=None,
+ ):
+ """
+ Computes patch lengths from token entropies.
+
+ Patch boundaries are placed wherever the next-byte entropy exceeds `threshold`; tokens 0 and 1
+ always start their own patches, and every other token is grouped into the patch opened by the
+ most recent boundary.
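+
+ Illustrative example (hypothetical values): for entropies = [[0.1, 0.2, 0.9, 0.1, 0.8, 0.1]],
+ sequence_length=6 and threshold=0.5, patches start at tokens [0, 1, 3, 5] and the returned
+ patch lengths are [[1, 2, 2, 1]].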
+ """
+
+ batch_size = entropies.shape[0]
+
+ # Always include tokens 0 and 1 as patch starting tokens
+ init_tokens = (
+ torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
+ )
+ offset = init_tokens.shape[1]
+
+ # Ignore first token entropy (BOS)
+ entropies = entropies[:, 1:]
+
+ # Threshold the entropy values to define patch start points
+ patch_mask = entropies > threshold
+
+ seq_len = patch_mask.shape[1]
+
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
+ sentinel = torch.full_like(token_indices, seq_len)
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
+
+ # Pad mask with inverse to align sentinel correctly
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
+
+ # Select indices where mask is True
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
+ max_valid_patches = patch_mask.sum(dim=1).max()
+ patch_starts = patch_starts[:, :max_valid_patches]
+
+ # Offset patch starts to account for the two initial tokens
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
+
+ # Compute patch end positions by shifting start positions
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
+
+ patch_lengths = patch_ends - patch_start_ids + 1
+
+ return patch_lengths
+
+
+def rolling_polynomial_hash(token_tensor, prime: int = 1000000007):
+ """
+ A polynomial rolling hash algorithm that converts sequences
+ of tokens into hash values. The hash is computed as:
+ hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)
+
+ The rolling hash allows the model to efficiently
+ identify and encode recurring byte-level patterns in the input text.
+
+ Args:
+ token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
+ prime (int): Prime number used as the base for the polynomial hash.
+
+ Returns:
+ torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
+ represents the hash of the corresponding token group
+
+ Example:
+ >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
+ >>> hashes = rolling_polynomial_hash(tokens, prime=31)
+ >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
+ >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
+ """
+ prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device)
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+ prime_powers = prime_tensor**powers
+ return torch.sum(token_tensor * prime_powers, dim=-1)
+
+
+def byte_group_hash_function(
+ token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000
+):
+ """Hash token groups and map to range [0, max_hash]."""
+ with torch.no_grad():
+ batch_size, seq_len = token_ids.shape
+ # Add padding for sliding window
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+ # Create sliding windows and compute hashes
+ windows = padded_tokens.unfold(1, group_size, 1)
+ hashes = rolling_polynomial_hash(windows, prime)
+ hash_values = hashes % max_hash
+
+ return hash_values
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.Embedding,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """Compute token embeddings enhanced with hash-based embeddings."""
+ # Available primes for hash functions
+ primes = [
+ 1000000007,
+ 5915587277,
+ 1500450271,
+ 3267000013,
+ 5754853343,
+ 4093082899,
+ 9576890767,
+ 3628273133,
+ 2860486313,
+ 5463458053,
+ 3367900313,
+ ]
+
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+ embedding_idx = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes
+ for group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
+ # Apply offset to get the correct slice of the fused embedding
+ offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
+ embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+ embedding_idx += 1
+
+ return embeddings
+
+
+def _prepare_patch_cross_attention_mask(
+ patch_ids: torch.Tensor,
+ num_patches: int,
+ sequence_length: int,
+ patches_as_queries: bool = False,
+ cross_attn_k: int = 1,
+ dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+ """
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
+
+ This function creates masks that control which patches can attend to which other patches,
+ with support for query/key role swapping and cross-attention multipliers.
+
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ sequence_length (int): Length of the sequence.
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
+ dtype (torch.dtype): Data type for the output mask.
+
+ Returns:
+ torch.Tensor: cross_attention_mask, a 4D tensor of shape [batch_size, 1, q_len, kv_len] with
+ 0.0 where attention is allowed and the dtype's minimum value where it is masked.
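+
+ Illustrative example (hypothetical values): with patch_ids = [[0, 0, 1]], num_patches=2,
+ sequence_length=3, patches_as_queries=False and cross_attn_k=1, the mask has shape [1, 1, 3, 2]
+ and is 0.0 exactly where a token's patch id matches the patch index (tokens 0 and 1 attend to
+ patch 0, token 2 attends to patch 1).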
+ """
+ batch_size, seq_len = patch_ids.shape
+ device = patch_ids.device
+
+ # Determine query and key lengths based on configuration
+ if patches_as_queries:
+ q_len = num_patches * cross_attn_k
+ kv_len = sequence_length
+ # Create patch-to-sequence mapping
+ q_patch_ids = (
+ torch.arange(num_patches, device=device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .expand(batch_size, num_patches, seq_len)
+ )
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
+ else:
+ q_len = sequence_length
+ kv_len = num_patches * cross_attn_k
+ # Create sequence-to-patch mapping
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
+ kv_patch_ids = (
+ torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches)
+ )
+
+ # Create base attention mask - boolean mask where True means "should attend"
+ # Exact patch matching
+ cross_attention_mask = q_patch_ids == kv_patch_ids
+
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
+ repeat_dim = 1 if patches_as_queries else -1
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
+
+ # Validate dimensions
+ expected_shape = (batch_size, q_len, kv_len)
+ if cross_attention_mask.shape != expected_shape:
+ raise ValueError(
+ f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}"
+ )
+
+ # Reshape so it can be used by attn module - add head dimension
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
+
+ # Invert the mask (following mllama pattern exactly)
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
+ inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ return cross_attention_mask
+
+
+class BltModel(BltPreTrainedModel):
+ def __init__(self, config: BltConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ self.config = config
+ self.local_encoder = BltLocalEncoder(config.encoder_config)
+ self.global_transformer = BltGlobalTransformer(config.global_config)
+ self.local_decoder = BltLocalDecoder(config.decoder_config)
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size)
+ total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings
+ self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size)
+ if self.config.patch_in_forward:
+ self.patcher = BltPatcher(config.patcher_config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ patch_lengths: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # Extract input embeddings as early as possible
+ if inputs_embeds is not None:
+ encoder_embeds = inputs_embeds
+ batch_size, sequence_length, _ = inputs_embeds.shape
+ else:
+ batch_size, sequence_length = input_ids.shape
+ encoder_embeds = compute_hash_embeddings(
+ input_ids,
+ self.local_encoder,
+ self.encoder_hash_tok_embedding,
+ self.config.encoder_hash_byte_group_nb_functions,
+ self.config.encoder_hash_byte_group_size,
+ self.config.encoder_hash_byte_group_vocab,
+ )
+
+ if patch_lengths is None:
+ if self.config.patching_mode == "entropy" and self.patcher is not None:
+ if input_ids is None:
+ raise ValueError("input_ids is required for entropy-based patching")
+ _, patch_lengths, _ = self.patcher(
+ input_ids,
+ patch_size=self.config.patch_size,
+ threshold=self.config.patching_threshold,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=input_ids.device,
+ )
+ else:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype
+ patch_lengths = process_patch_lengths(
+ torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device),
+ self.config.max_patch_length,
+ )
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=encoder_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ cross_attn_mask_enc = _prepare_patch_cross_attention_mask(
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype
+ )
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
+ input_ids=input_ids,
+ inputs_embeds=encoder_embeds,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ encoder_attention_mask=cross_attn_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ **kwargs,
+ )
+ encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
+ global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device)
+ global_position_ids = global_cache_position.unsqueeze(0)
+ global_causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=encoder_cross_states,
+ attention_mask=None,
+ cache_position=global_cache_position,
+ past_key_values=None,
+ position_ids=None,
+ )
+
+ global_hidden_states = self.global_transformer(
+ input_embeds=encoder_cross_states,
+ attention_mask=global_causal_mask,
+ position_ids=global_position_ids,
+ **kwargs,
+ )
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
+ cross_attn_mask_dec = _prepare_patch_cross_attention_mask(
+ decoder_patch_ids,
+ patch_lengths.shape[1],
+ sequence_length,
+ False,
+ self.config.cross_attn_k,
+ encoder_embeds.dtype,
+ )
+ output = self.local_decoder(
+ input_ids=input_ids,
+ inputs_embeds=encoder_hidden_states,
+ patch_embeds=global_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ encoder_attention_mask=cross_attn_mask_dec,
+ **kwargs,
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=output,
+ past_key_values=past_key_values,
+ )
+
+ def get_input_embeddings(self):
+ return self.local_encoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.local_encoder.embed_tokens = value
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
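+ # Maps each token position to the index of the patch it belongs to. Illustrative example
+ # (hypothetical values): patch_lengths = [[2, 3, 1]] with seq_len=6 gives patch starts [0, 2, 5]
+ # and patch_ids [[0, 0, 1, 1, 1, 2]].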
+ batch_size = patch_lengths.shape[0]
+ patch_starts = torch.cat(
+ [
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1],
+ ],
+ dim=-1,
+ )
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
+
+
+@auto_docstring(
+ custom_intro="""
+ The BLT model with a language modeling head on top.
+ """
+)
+class BltForCausalLM(BltPreTrainedModel, GenerationMixin):
+ config: BltConfig
+ _can_compile_fullgraph = False
+ base_model_prefix = "model"
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: BltConfig):
+ super().__init__(config.get_text_config())
+ self.text_config = config.get_text_config()
+ self.vocab_size = config.vocab_size
+ self.model = BltModel(config)
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
+
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility
+ cross_attention_mask: Optional[torch.LongTensor] = None,
+ full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ cross_attention_states (`torch.FloatTensor`, *optional*):
+ Accepted for compatibility with the Mllama-style cross-attention API; BLT is a byte-level text
+ model and this argument is not forwarded to the underlying `BltModel`.
+ cross_attention_mask (`torch.Tensor`, *optional*):
+ Accepted for compatibility with the Mllama-style cross-attention API and passed through to the
+ model; BLT builds its own patch-level cross-attention masks internally.
+ full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
+ Accepted for compatibility with the Mllama-style cross-attention API and passed through to the
+ model.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, BltForCausalLM
+
+ >>> model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+ >>> prompt = "If I had to write a haiku, it would be:"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
+ >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ >>> print(result)
+ If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
+ I love the idea of snowflakes gently falling, each one
+ ```
+ """
+ # Run the base BltModel; cross_attention_states is accepted for compatibility but not passed on
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"]
diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py
new file mode 100644
index 000000000000..0b04966d97fe
--- /dev/null
+++ b/src/transformers/models/blt/modular_blt.py
@@ -0,0 +1,1008 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Blt modular model, inheriting from Mllama where appropriate."""
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.distributions
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..cohere2.modeling_cohere2 import (
+ Cohere2RotaryEmbedding,
+ rotate_half, # noqa: F401
+)
+from ..mllama.modeling_mllama import (
+ MllamaForCausalLM,
+ MllamaPreTrainedModel,
+ MllamaSelfAttentionDecoderLayer,
+ MllamaTextCrossAttention,
+ MllamaTextMLP,
+ MllamaTextRMSNorm,
+ MllamaTextSelfAttention,
+ eager_attention_forward,
+)
+from .configuration_blt import (
+ BltConfig,
+ BltGlobalTransformerConfig,
+ BltLocalDecoderConfig,
+ BltLocalEncoderConfig,
+ BltPatcherConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+def rolling_polynomial_hash(token_tensor, prime: int = 1000000007):
+ """
+ A polynomial rolling hash algorithm that converts sequences
+ of tokens into hash values. The hash is computed as:
+ hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)
+
+ The rolling hash allows the model to efficiently
+ identify and encode recurring byte-level patterns in the input text.
+
+ Args:
+ token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
+ prime (int): Prime number used as the base for the polynomial hash.
+
+ Returns:
+ torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
+ represents the hash of the corresponding token group
+
+ Example:
+ >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
+ >>> hashes = rolling_polynomial_hash(tokens, prime=31)
+ >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
+ >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
+ """
+ prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device)
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+ prime_powers = prime_tensor**powers
+ return torch.sum(token_tensor * prime_powers, dim=-1)
+
+
+def byte_group_hash_function(
+ token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000
+):
+ """Hash token groups and map to range [0, max_hash]."""
+ with torch.no_grad():
+ batch_size, seq_len = token_ids.shape
+ # Add padding for sliding window
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+ # Create sliding windows and compute hashes
+ windows = padded_tokens.unfold(1, group_size, 1)
+ hashes = rolling_polynomial_hash(windows, prime)
+ hash_values = hashes % max_hash
+
+ return hash_values
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.Embedding,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """Compute token embeddings enhanced with hash-based embeddings."""
+ # Available primes for hash functions
+ primes = [
+ 1000000007,
+ 5915587277,
+ 1500450271,
+ 3267000013,
+ 5754853343,
+ 4093082899,
+ 9576890767,
+ 3628273133,
+ 2860486313,
+ 5463458053,
+ 3367900313,
+ ]
+
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+ embedding_idx = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes
+ for group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
+ # Apply offset to get the correct slice of the fused embedding
+ offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
+ embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+ embedding_idx += 1
+
+ return embeddings
+
+
+def _prepare_patch_cross_attention_mask(
+ patch_ids: torch.Tensor,
+ num_patches: int,
+ sequence_length: int,
+ patches_as_queries: bool = False,
+ cross_attn_k: int = 1,
+ dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+ """
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
+
+ This function creates masks that control which patches can attend to which other patches,
+ with support for query/key role swapping and cross-attention multipliers.
+
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ sequence_length (int): Length of the sequence.
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
+ dtype (torch.dtype): Data type for the output mask.
+
+ Returns:
+ torch.Tensor: cross_attention_mask, a 4D tensor of shape [batch_size, 1, q_len, kv_len] with
+ 0.0 where attention is allowed and the dtype's minimum value where it is masked.
+ """
+ batch_size, seq_len = patch_ids.shape
+ device = patch_ids.device
+
+ # Determine query and key lengths based on configuration
+ if patches_as_queries:
+ q_len = num_patches * cross_attn_k
+ kv_len = sequence_length
+ # Create patch-to-sequence mapping
+ q_patch_ids = (
+ torch.arange(num_patches, device=device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .expand(batch_size, num_patches, seq_len)
+ )
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
+ else:
+ q_len = sequence_length
+ kv_len = num_patches * cross_attn_k
+ # Create sequence-to-patch mapping
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
+ kv_patch_ids = (
+ torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches)
+ )
+
+ # Create base attention mask - boolean mask where True means "should attend"
+ # Exact patch matching
+ cross_attention_mask = q_patch_ids == kv_patch_ids
+
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
+ repeat_dim = 1 if patches_as_queries else -1
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
+
+ # Validate dimensions
+ expected_shape = (batch_size, q_len, kv_len)
+ if cross_attention_mask.shape != expected_shape:
+ raise ValueError(
+ f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}"
+ )
+
+ # Reshape so it can be used by attn module - add head dimension
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
+
+ # Invert the mask (following mllama pattern exactly)
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
+ inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ return cross_attention_mask
+
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+ """
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+ Pads the result to uniform length across the batch.
+
+ Args:
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+ max_patch_length (int, optional): Maximum allowed length per patch.
+
+ Returns:
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+ """
+ if max_patch_length is None:
+ return patch_lengths
+
+ batch_size = patch_lengths.size(0)
+ processed = []
+
+ for seq in patch_lengths:
+ splits = []
+ for length in seq[seq > 0]:
+ length = length.item()
+ full_chunks, remainder = divmod(length, max_patch_length)
+ splits.extend([max_patch_length] * full_chunks)
+ if remainder:
+ splits.append(remainder)
+ processed.append(splits)
+
+ # Find max length to pad to
+ max_len = max(len(splits) for splits in processed)
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ for i, splits in enumerate(processed):
+ if splits:
+ padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ # Trim zero columns
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+ padded = padded[:, :last_nonzero]
+
+ return padded
+
+
+class BltMLP(MllamaTextMLP):
+ pass
+
+
+class BltRMSNorm(MllamaTextRMSNorm):
+ pass
+
+
+class BltRotaryEmbedding(Cohere2RotaryEmbedding):
+ pass
+
+
+class BltTransformerLayer(MllamaSelfAttentionDecoderLayer):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+
+ self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx)
+ self.mlp = BltMLP(config)
+ self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class BltSelfAttention(MllamaTextSelfAttention):
+ def __init__(self, config: BltConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.is_causal = True
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ use_cache: bool = False,
+ past_key_values=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+
+class BltCrossAttention(MllamaTextCrossAttention):
+ """Cross-attention module for Blt, following transformers style"""
+
+ def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None):
+ super().__init__()
+ self.is_causal = False
+ self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_norm(hidden_states)
+ query_states = self.q_proj(query_states)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if cross_attention_states is not None:
+ cross_attention_states = self.k_norm(cross_attention_states)
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ elif cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_values.layers[self.layer_idx].keys,
+ past_key_values.layers[self.layer_idx].values,
+ )
+ else:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+ )
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ attn_output = attn_output + hidden_states
+ return attn_output, attn_weights
+
+
+@auto_docstring
+class BltPreTrainedModel(MllamaPreTrainedModel):
+ config: BltConfig
+ _supports_attention_backend = False
+ _no_split_modules = ["BltTransformerLayer"]
+ _can_record_outputs = {
+ "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"),
+ "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
+ }
+
+ def _init_weights(self, module):
+ raise AttributeError("No need to inherit it!")
+
+ def _update_causal_mask(self, module):
+ raise AttributeError("No need to inherit it!")
+
+ def _prepare_4d_causal_attention_mask_with_cache_position(self, module):
+ raise AttributeError("No need to inherit it!")
+
+
+class BltLocalEncoder(BltPreTrainedModel):
+ config: BltLocalEncoderConfig
+ _can_record_outputs = {
+ "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
+ }
+
+ def __init__(self, config: BltLocalEncoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+ self.config = config
+ self.layers = nn.ModuleList(
+ [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.cross_attn_layers = nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size = inputs_embeds.shape[0]
+ hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
+
+ if position_ids is None:
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ for idx, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids)
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(
+ batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
+ )
+ layer_idx = idx if self.config.cross_attn_all_layers else 0
+ cross_attention_output, _ = self.cross_attn_layers[layer_idx](
+ hidden_states=patch_embeds,
+ cross_attention_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ patch_embeds = patch_embeds + cross_attention_output
+ encoder_cross_states = patch_embeds
+ return hidden_states, encoder_cross_states
+
+ def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
+ """
+ Reduce variable-length patches to a single embedding per patch.
+ Note: this works with a variable number of patches per sequence in the batch.
+ Variable-length patches are handled by assuming that patch_lengths will be 0 for any
+ extra patches on the *right*. Any embeddings on the right that are not allocated to a patch
+ (i.e. if sum(patch_lengths[i]) < seq_len for some i) are sent to a dummy patch, which is
+ trimmed before returning.
+ """
+ batch_size = hidden_states.shape[0]
+ embedding_dim = hidden_states.shape[-1]
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
+
+ reduced_embeddings = torch.zeros(
+ (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
+ src=hidden_states,
+ dim=1,
+ index=patch_ids,
+ reduce="amax",
+ include_self=False,
+ )
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
+
+ return reduced_embeddings
+
+
+class BltLocalDecoder(BltPreTrainedModel):
+ config: BltLocalDecoderConfig
+
+ def __init__(self, config: BltLocalDecoderConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+ self.config = config
+ self.cross_attn_decoder = True
+ self.layers = nn.ModuleList(
+ [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size_global,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+ self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.cross_attn_layers = nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ batch_size = inputs_embeds.shape[0]
+ hidden_states = inputs_embeds
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(
+ batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
+ )
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ hidden_states = hidden_states + patch_embeds
+
+ if position_ids is None:
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ for i, layer in enumerate(self.layers):
+ if i == 0 or self.config.cross_attn_all_layers:
+ cross_attention_output, _ = self.cross_attn_layers[i](
+ hidden_states=hidden_states,
+ cross_attention_states=patch_embeds,
+ attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ hidden_states = hidden_states + cross_attention_output
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ logits = self.norm(hidden_states)
+ return logits
+
+
+class BltGlobalTransformer(BltPreTrainedModel):
+ config: BltGlobalTransformerConfig
+ _can_record_outputs = {
+ "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
+ }
+
+ def __init__(self, config: BltGlobalTransformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.layers = nn.ModuleList()
+ for layer_idx in range(config.num_hidden_layers):
+ self.layers.append(BltTransformerLayer(config, layer_idx))
+ self.rotary_emb = BltRotaryEmbedding(config=config)
+
+ # Create token embedding projection (use nn.Identity() when no projection needed)
+ if getattr(config, "encoder_cross_output_size", None) is not None:
+ self.token_embedding_projection = nn.Linear(
+ config.encoder_cross_output_size, config.hidden_size, bias=False
+ )
+ else:
+ self.token_embedding_projection = nn.Identity()
+
+ self.post_init()
+
+ def forward(
+ self,
+ input_embeds: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
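+ # Self-attention over the patch-level representations produced by the local encoder.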
+ batch_size, seq_len, _ = input_embeds.shape
+ hidden_states = self.token_embedding_projection(input_embeds)
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+ if position_ids is None:
+ position_ids = (
+ torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ )
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ return hidden_states
+
+
+class BltPatcher(BltPreTrainedModel):
+ config: BltPatcherConfig
+
+ def __init__(self, config: BltPatcherConfig):
+ super().__init__(config)
+ self.rotary_emb = BltRotaryEmbedding(config=self.config)
+ self.layers = nn.ModuleList()
+ for layer_idx in range(self.config.num_hidden_layers):
+ self.layers.append(BltTransformerLayer(self.config, layer_idx))
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+ self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.lm_head = nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ patch_size: Optional[int] = None,
+ threshold: Optional[float] = None,
+ max_patch_length: Optional[int] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)
+
+ logits = self.lm_head(self.norm(hidden_states))
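+ # Entropy of the predicted next-byte distribution at each position; downstream,
+ # high-entropy positions mark the start of new patches.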
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
+
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+ if patch_size is not None:
+ patch_lengths = self.patch_lengths_from_entropies(
+ entropies=prediction_entropies,
+ sequence_length=sequence_length,
+ patch_size=patch_size,
+ threshold=threshold,
+ )
+ else:
+ patch_lengths = torch.ones(
+ (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
+ return prediction_entropies, patch_lengths, logits
+
+ @staticmethod
+ def patch_lengths_from_entropies(
+ entropies,
+ sequence_length,
+ patch_size=None,
+ threshold=None,
+ ):
+ """
+ Computes patch lengths from next-byte entropies.
+
+ Positions 0 and 1 always start new patches; afterwards, byte `p + 1` starts a
+ new patch whenever the entropy at position `p` exceeds `threshold`.
+ """
+
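+ # Illustrative example (hypothetical values): with sequence_length=6, threshold=0.5
+ # and entropies [[0.1, 0.9, 0.2, 0.2, 0.8, 0.1]], the entropies at positions 1 and 4
+ # exceed the threshold, so bytes 2 and 5 start new patches in addition to the
+ # always-included starts 0 and 1. Patch starts [0, 1, 2, 5] yield patch lengths
+ # [1, 1, 3, 1], which sum to the sequence length.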
+ batch_size = entropies.shape[0]
+
+ # Always include token 0 and 1 as starting tokens
+ init_tokens = (
+ torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
+ )
+ offset = init_tokens.shape[1]
+
+ # Ignore first token entropy (BOS)
+ entropies = entropies[:, 1:]
+
+ # Threshold the entropy values to define patch start points
+ patch_mask = entropies > threshold
+
+ seq_len = patch_mask.shape[1]
+
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
+ sentinel = torch.full_like(token_indices, seq_len)
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
+
+ # Pad mask with inverse to align sentinel correctly
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
+
+ # Select indices where mask is True
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
+ max_valid_patches = patch_mask.sum(dim=1).max()
+ patch_starts = patch_starts[:, :max_valid_patches]
+
+ # Offset patch starts to account for the two initial tokens
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
+
+ # Compute patch end positions by shifting start positions
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
+
+ patch_lengths = patch_ends - patch_start_ids + 1
+
+ return patch_lengths
+
+
+class BltModel(BltPreTrainedModel):
+ def __init__(self, config: BltConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ self.config = config
+ self.local_encoder = BltLocalEncoder(config.encoder_config)
+ self.global_transformer = BltGlobalTransformer(config.global_config)
+ self.local_decoder = BltLocalDecoder(config.decoder_config)
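+ # One hash-embedding table slice per (hash function, byte n-gram size) pair, all
+ # stored in a single nn.Embedding of size encoder_hash_byte_group_vocab * num_embeddings.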
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size)
+ total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings
+ self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size)
+ if self.config.patch_in_forward:
+ self.patcher = BltPatcher(config.patcher_config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ patch_lengths: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
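+ # Three-stage pipeline: the local encoder maps byte embeddings to patch
+ # representations, the global transformer contextualizes the patches, and the
+ # local decoder cross-attends from bytes back to patches to produce the final
+ # byte-level hidden states.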
+ # Extract input embeddings as early as possible
+ if inputs_embeds is not None:
+ encoder_embeds = inputs_embeds
+ batch_size, sequence_length, _ = inputs_embeds.shape
+ else:
+ batch_size, sequence_length = input_ids.shape
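+ # Byte n-grams (sizes in `encoder_hash_byte_group_size`) are hashed into the shared
+ # hash-embedding table and combined with the plain byte embeddings.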
+ encoder_embeds = compute_hash_embeddings(
+ input_ids,
+ self.local_encoder,
+ self.encoder_hash_tok_embedding,
+ self.config.encoder_hash_byte_group_nb_functions,
+ self.config.encoder_hash_byte_group_size,
+ self.config.encoder_hash_byte_group_vocab,
+ )
+
+ if patch_lengths is None:
+ if self.config.patching_mode == "entropy" and self.patcher is not None:
+ if input_ids is None:
+ raise ValueError("input_ids is required for entropy-based patching")
+ _, patch_lengths, _ = self.patcher(
+ input_ids,
+ patch_size=self.config.patch_size,
+ threshold=self.config.patching_threshold,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=input_ids.device,
+ )
+ else:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype
+ patch_lengths = process_patch_lengths(
+ torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device),
+ self.config.max_patch_length,
+ )
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=encoder_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ cross_attn_mask_enc = _prepare_patch_cross_attention_mask(
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype
+ )
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
+ input_ids=input_ids,
+ inputs_embeds=encoder_embeds,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ encoder_attention_mask=cross_attn_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ **kwargs,
+ )
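+ # The encoder emits `cross_attn_k` vectors per patch; flatten them into a single
+ # patch-level token before running the global transformer.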
+ encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
+ global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device)
+ global_position_ids = global_cache_position.unsqueeze(0)
+ global_causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=encoder_cross_states,
+ attention_mask=None,
+ cache_position=global_cache_position,
+ past_key_values=None,
+ position_ids=None,
+ )
+
+ global_hidden_states = self.global_transformer(
+ input_embeds=encoder_cross_states,
+ attention_mask=global_causal_mask,
+ position_ids=global_position_ids,
+ **kwargs,
+ )
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
+ cross_attn_mask_dec = _prepare_patch_cross_attention_mask(
+ decoder_patch_ids,
+ patch_lengths.shape[1],
+ sequence_length,
+ False,
+ self.config.cross_attn_k,
+ encoder_embeds.dtype,
+ )
+ output = self.local_decoder(
+ input_ids=input_ids,
+ inputs_embeds=encoder_hidden_states,
+ patch_embeds=global_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ encoder_attention_mask=cross_attn_mask_dec,
+ **kwargs,
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=output,
+ past_key_values=past_key_values,
+ )
+
+ def get_input_embeddings(self):
+ return self.local_encoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.local_encoder.embed_tokens = value
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
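+ # Maps every byte position to the index of the patch that contains it, e.g.
+ # patch_lengths [[2, 3, 1]] with seq_len=6 gives patch starts [[0, 2, 5]] and
+ # patch ids [[0, 0, 1, 1, 1, 2]] (illustrative values).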
+ batch_size = patch_lengths.shape[0]
+ patch_starts = torch.cat(
+ [
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1],
+ ],
+ dim=-1,
+ )
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
+
+
+class BltForCausalLM(MllamaForCausalLM):
+ config: BltConfig
+ _can_compile_fullgraph = False
+ base_model_prefix = "model"
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: BltConfig):
+ super().__init__(config)
+ self.vocab_size = config.vocab_size
+ self.model = BltModel(config)
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility
+ cross_attention_mask: Optional[torch.LongTensor] = None,
+ full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ # Call the inner model directly, dropping cross_attention_states (kept in the signature only for compatibility)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
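+ # Only keep logits for the last `logits_to_keep` positions (0 keeps them all),
+ # which avoids materializing the full vocab projection during generation.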
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "BltPreTrainedModel",
+ "BltModel",
+ "BltPatcher",
+ "BltForCausalLM",
+]
diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py
index 8600f1dc265e..4757d4b69c6c 100644
--- a/tests/causal_lm_tester.py
+++ b/tests/causal_lm_tester.py
@@ -497,7 +497,7 @@ def _config_supports_rope_scaling(config: PretrainedConfig) -> bool:
# Has rope_theta (and no rope_scaling) -> probably an older model, but should support rope scaling as well
main_config_has_rope = hasattr(config, "rope_scaling") or hasattr(config, "rope_theta")
sub_config_has_rope = any(
- hasattr(config[sub_config], "rope_scaling") or hasattr(config[sub_config], "rope_theta")
+ hasattr(getattr(config, sub_config), "rope_scaling") or hasattr(getattr(config, sub_config), "rope_theta")
for sub_config in config.sub_configs.keys()
)
return main_config_has_rope or sub_config_has_rope
diff --git a/tests/models/blt/__init__.py b/tests/models/blt/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py
new file mode 100644
index 000000000000..dc4703974781
--- /dev/null
+++ b/tests/models/blt/test_modeling_blt.py
@@ -0,0 +1,561 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Blt model."""
+
+import unittest
+
+import pytest
+from parameterized import parameterized
+
+from transformers import AutoTokenizer, is_torch_available, set_seed
+from transformers.testing_utils import (
+ cleanup,
+ require_read_token,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_bf16,
+ slow,
+ torch_device,
+)
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+from ...test_modeling_common import (
+ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
+ _test_eager_matches_sdpa_inference,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import BltConfig, BltForCausalLM, BltModel
+from transformers.models.blt.modeling_blt import BltRotaryEmbedding
+
+
+class BltModelTester(CausalLMModelTester):
+ if is_torch_available():
+ config_class = BltConfig
+ base_model_class = BltModel
+ causal_lm_class = BltForCausalLM
+
+ def __init__(
+ self,
+ parent,
+ ignore_index=-100,
+ seq_length=7,
+ is_training=True,
+ ):
+ super().__init__(parent)
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.batch_size = 3
+
+ # Common parameters for all configs
+ self.hidden_size = 16
+ self.num_hidden_layers = 1
+ self.num_attention_heads = 2
+ self.num_key_value_heads = 2
+ self.intermediate_size = 32
+ self.hidden_act = "silu"
+ self.max_position_embeddings = 32
+ self.vocab_size = 32
+ self.rope_theta = 500000.0
+ self.rope_scaling = {"rope_type": "default"}
+ self.rms_norm_eps = 1e-5
+ self.dropout = 0.0
+ self.encoder_hash_byte_group_size = [2, 3]
+ self.encoder_hash_byte_group_vocab = 64
+ self.encoder_hash_byte_group_nb_functions = 1
+ # Sub-configs for the patcher, encoder, decoder, and global transformer
+ self.patcher_config = {
+ "hidden_size": self.hidden_size,
+ "num_hidden_layers": self.num_hidden_layers,
+ "num_attention_heads": self.num_attention_heads,
+ "num_key_value_heads": self.num_key_value_heads,
+ "intermediate_size": self.intermediate_size,
+ "max_position_embeddings": self.max_position_embeddings,
+ "rope_theta": self.rope_theta,
+ "rope_scaling": self.rope_scaling,
+ "hidden_act": self.hidden_act,
+ "rms_norm_eps": self.rms_norm_eps,
+ "dropout": self.dropout,
+ }
+
+ self.encoder_config = {
+ "hidden_size": self.hidden_size,
+ "num_hidden_layers": self.num_hidden_layers,
+ "num_attention_heads": self.num_attention_heads,
+ "num_key_value_heads": self.num_key_value_heads,
+ "intermediate_size": self.intermediate_size,
+ "max_position_embeddings": self.max_position_embeddings,
+ "rope_theta": self.rope_theta,
+ "rope_scaling": self.rope_scaling,
+ "hidden_act": self.hidden_act,
+ "rms_norm_eps": self.rms_norm_eps,
+ "dropout": self.dropout,
+ }
+
+ self.decoder_config = {
+ "vocab_size": self.vocab_size,
+ "hidden_size": self.hidden_size,
+ "hidden_size_global": self.hidden_size * 2, # Must match global transformer output size
+ "num_hidden_layers": self.num_hidden_layers,
+ "num_attention_heads": self.num_attention_heads,
+ "num_key_value_heads": self.num_key_value_heads,
+ "intermediate_size": self.intermediate_size,
+ "max_position_embeddings": self.max_position_embeddings,
+ "rope_theta": self.rope_theta,
+ "rope_scaling": self.rope_scaling,
+ "hidden_act": self.hidden_act,
+ "rms_norm_eps": self.rms_norm_eps,
+ "dropout": self.dropout,
+ }
+
+ self.global_config = {
+ "hidden_size": self.hidden_size * 2, # Double the hidden size for global transformer
+ "num_hidden_layers": self.num_hidden_layers,
+ "num_attention_heads": self.num_attention_heads,
+ "num_key_value_heads": self.num_key_value_heads,
+ "intermediate_size": self.intermediate_size,
+ "max_position_embeddings": self.max_position_embeddings,
+ "rope_theta": self.rope_theta,
+ "rope_scaling": self.rope_scaling,
+ "hidden_act": self.hidden_act,
+ "rms_norm_eps": self.rms_norm_eps,
+ "dropout": self.dropout,
+ }
+
+ self.num_hidden_layers = self.encoder_config["num_hidden_layers"]
+
+ def get_config(self):
+ config = BltConfig(
+ vocab_size=self.vocab_size,
+ max_position_embeddings=self.max_position_embeddings,
+ patch_in_forward=False, # Disable patching for tests
+ patch_size=4,
+ patching_mode="entropy",
+ patching_threshold=1.335442066192627,
+ patching_batch_size=1,
+ max_patch_length=None,
+ cross_attn_k=2,
+ encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
+ encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
+ encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
+ patcher_config=self.patcher_config,
+ encoder_config=self.encoder_config,
+ decoder_config=self.decoder_config,
+ global_config=self.global_config,
+ rope_scaling=self.rope_scaling,
+ tie_word_embeddings=False,
+ )
+
+ config.num_attention_heads = config.decoder_config.num_attention_heads
+ config.num_hidden_layers = config.encoder_config.num_hidden_layers
+ config.hidden_size = config.decoder_config.hidden_size
+
+ return config
+
+
+@require_torch
+class BltModelTest(CausalLMModelTest, unittest.TestCase):
+ all_model_classes = (
+ (
+ BltModel,
+ BltForCausalLM,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": BltModel,
+ "text-generation": BltForCausalLM,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+ fx_compatible = False
+ model_tester_class = BltModelTester
+ rotary_embedding_layer = BltRotaryEmbedding # Enables RoPE tests if set
+
+ # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
+ # This is because we are hitting edge cases with the causal_mask buffer
+ model_split_percents = [0.5, 0.7, 0.8]
+
+ # used in `test_torch_compile_for_training`
+ _torch_compile_train_cls = BltForCausalLM if is_torch_available() else None
+
+ @pytest.mark.generate
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ @unittest.skip(
+ "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs"
+ )
+ def test_generate_from_inputs_embeds(self, _, num_beams):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(
+ "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs"
+ )
+ def test_inputs_embeds_matches_input_ids(self):
+ pass
+
+ @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
+ def test_eager_matches_sdpa_inference(
+ self,
+ name,
+ torch_dtype,
+ padding_side,
+ use_attention_mask,
+ output_attentions,
+ enable_kernels,
+ ):
+ "We need to relax a bit the `atols` for fp32 here due to the altup projections"
+ atols = {
+ ("cpu", False, torch.float32): 2e-2, # this was relaxed
+ ("cpu", False, torch.float16): 5e-3,
+ ("cpu", False, torch.bfloat16): 1e-2,
+ ("cpu", True, torch.float32): 2e-2, # this was relaxed
+ ("cpu", True, torch.float16): 5e-3,
+ ("cpu", True, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float32): 2e-2, # this was relaxed
+ ("cuda", False, torch.bfloat16): 1e-2,
+ ("cuda", False, torch.float16): 5e-3,
+ ("cuda", True, torch.float32): 2e-2, # this was relaxed
+ ("cuda", True, torch.bfloat16): 1e-2,
+ ("cuda", True, torch.float16): 5e-3,
+ }
+ _test_eager_matches_sdpa_inference(
+ self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols
+ )
+
+ @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
+ def test_model_rope_scaling_from_config(self, scaling_type):
+ """Override rope scaling from config test to handle Blt's sub-config structure."""
+ if self.rotary_embedding_layer is None:
+ self.skipTest("Rotary embedding layer not set")
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ short_input = ids_tensor([1, 10], config.vocab_size)
+ long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
+
+ set_seed(42) # Fixed seed at init time so the two models get the same random weights
+ original_model = self.model_tester_class.base_model_class(config)
+ original_model.to(torch_device)
+ original_model.eval()
+ original_short_output = original_model(short_input).last_hidden_state
+ original_long_output = original_model(long_input).last_hidden_state
+
+ set_seed(42) # Fixed seed at init time so the two models get the same random weights
+ config.rope_scaling = {"rope_type": scaling_type, "factor": 10.0}
+ # Propagate rope_scaling to sub-configs for Blt
+ config.encoder_config.rope_scaling = config.rope_scaling
+ config.decoder_config.rope_scaling = config.rope_scaling
+ config.global_config.rope_scaling = config.rope_scaling
+ config.patcher_config.rope_scaling = config.rope_scaling
+
+ scaled_model = self.model_tester_class.base_model_class(config)
+ scaled_model.to(torch_device)
+ scaled_model.eval()
+ scaled_short_output = scaled_model(short_input).last_hidden_state
+ scaled_long_output = scaled_model(long_input).last_hidden_state
+
+ # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
+ # maximum sequence length, so the outputs for the short input should match.
+ if scaling_type == "dynamic":
+ torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
+ else:
+ self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+
+ self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ pass
+
+
+@require_torch_accelerator
+class BltIntegrationTest(unittest.TestCase):
+ def tearDown(self):
+ # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
+ # some memory allocated in the cache, which means some object is not being released properly. This causes some
+ # suboptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
+ # Investigate the root cause.
+ cleanup(torch_device, gc_collect=False)
+
+ @slow
+ @require_read_token
+ def test_model(self):
+ NUM_TOKENS_TO_GENERATE = 200
+ EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+
+ prompt = "my name is"
+
+ model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa")
+
+ tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
+ )
+
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ self.assertEqual(output_text, EXPECTED_TEXT)
+
+ @slow
+ @require_read_token
+ def test_model_logits(self):
+ EXPECTED_OUTPUT = torch.tensor(
+ [
+ [
+ -10.4948,
+ -10.7065,
+ -6.1813,
+ -10.5545,
+ -10.3428,
+ -9.1493,
+ -8.4937,
+ -8.6382,
+ -9.2159,
+ -9.5907,
+ -9.3679,
+ -8.4184,
+ -9.0655,
+ -3.4436,
+ 2.9616,
+ -10.3157,
+ -6.3723,
+ -6.0133,
+ -9.7100,
+ -9.2128,
+ -8.8064,
+ -9.8179,
+ -9.7516,
+ -9.4681,
+ -9.7715,
+ -9.4897,
+ -9.0491,
+ -9.8098,
+ -9.4648,
+ -9.3294,
+ ],
+ [
+ -13.3010,
+ -13.1910,
+ -5.7230,
+ -13.2895,
+ -13.4864,
+ -8.7140,
+ -7.0275,
+ -7.0182,
+ -10.1362,
+ -10.3762,
+ -9.9086,
+ -7.8049,
+ -8.8660,
+ -5.2711,
+ -3.5778,
+ -12.5346,
+ -9.1609,
+ -6.7925,
+ -10.3717,
+ -9.2650,
+ -10.6393,
+ -11.4807,
+ -11.2128,
+ -10.9615,
+ -10.5806,
+ -10.8873,
+ -11.0651,
+ -11.3471,
+ -10.5437,
+ -9.9688,
+ ],
+ ]
+ ).to(torch_device)
+
+ input_ids = [1, 42, 21, 12, 43, 23, 1, 4]
+
+ model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", attn_implementation="sdpa", device_map="auto")
+
+ with torch.no_grad():
+ output = model(torch.tensor([input_ids]).to(torch_device))[0]
+
+ torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4)
+
+ @slow
+ @require_read_token
+ @require_torch_bf16
+ def test_model_bf16(self):
+ """Test Blt model with bfloat16 precision."""
+ NUM_TOKENS_TO_GENERATE = 200
+ EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+
+ prompt = "my name is"
+
+ model = BltForCausalLM.from_pretrained(
+ "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
+ )
+
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ self.assertEqual(output_text, EXPECTED_TEXT)
+
+ @slow
+ @require_read_token
+ @require_torch_bf16
+ def test_model_logits_bf16(self):
+ """Test Blt model logits with bfloat16 precision."""
+
+ EXPECTED_OUTPUT = torch.tensor(
+ [
+ [
+ -10.5000,
+ -10.6875,
+ -6.1875,
+ -10.5625,
+ -10.3125,
+ -9.1875,
+ -8.5000,
+ -8.6875,
+ -9.1875,
+ -9.5625,
+ -9.3750,
+ -8.5000,
+ -9.0625,
+ -3.4219,
+ 2.9531,
+ -10.3125,
+ -6.4062,
+ -6.0000,
+ -9.6875,
+ -9.1875,
+ -8.8125,
+ -9.8125,
+ -9.7500,
+ -9.4375,
+ -9.8125,
+ -9.5000,
+ -9.0000,
+ -9.8125,
+ -9.4375,
+ -9.3125,
+ ],
+ [
+ -13.2500,
+ -13.1875,
+ -5.6875,
+ -13.3125,
+ -13.5000,
+ -8.7500,
+ -7.0625,
+ -7.0312,
+ -10.1250,
+ -10.3750,
+ -9.8750,
+ -7.8438,
+ -8.8750,
+ -5.2812,
+ -3.5625,
+ -12.5000,
+ -9.1875,
+ -6.8125,
+ -10.3750,
+ -9.3125,
+ -10.6250,
+ -11.5000,
+ -11.2500,
+ -11.0000,
+ -10.5625,
+ -10.8750,
+ -11.0625,
+ -11.3750,
+ -10.5625,
+ -10.0000,
+ ],
+ ]
+ ).to(torch_device)
+
+ input_ids = [1, 42, 21, 12, 43, 23, 1, 4]
+
+ model = BltForCausalLM.from_pretrained(
+ "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
+ )
+
+ with torch.no_grad():
+ output = model(torch.tensor([input_ids]).to(torch_device))[0]
+
+ torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3)
+
+ @slow
+ @require_read_token
+ def test_model_eager(self):
+ """Test Blt model with bfloat16 precision using eager attention implementation."""
+ NUM_TOKENS_TO_GENERATE = 200
+ EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+
+ prompt = "my name is"
+
+ model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="eager")
+
+ tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
+ )
+
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ self.assertEqual(output_text, EXPECTED_TEXT)
+
+ @slow
+ @require_read_token
+ @require_torch_bf16
+ def test_model_bf16_static_cache(self):
+ """Test Blt model with bfloat16 precision and static cache."""
+ NUM_TOKENS_TO_GENERATE = 200
+ EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+
+ prompt = "my name is"
+
+ model = BltForCausalLM.from_pretrained(
+ "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
+ )
+
+ model.generation_config.cache_implementation = "static"
+
+ tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
+
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
+ )
+
+ output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ self.assertEqual(output_text, EXPECTED_TEXT)
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 9eeda74afa48..be52c5298472 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -128,6 +128,8 @@
"BlipVisionConfig",
"BloomConfig",
"BloomTokenizerFast",
+ "BLTConfig",
+ "BLTPatcherConfig",
"BridgeTowerTextConfig",
"BridgeTowerVisionConfig",
"BrosModel",
@@ -460,6 +462,8 @@
"ZeroShotImageClassificationPipeline",
"ZeroShotObjectDetectionPipeline",
"Llama4TextConfig",
+ "BltConfig",
+ "BltPatcherConfig",
}
# In addition to the objects above, we also ignore objects with certain prefixes. If you add an item to the list
# below, make sure to add a comment explaining why.
diff --git a/utils/check_repo.py b/utils/check_repo.py
index e932e5bfc24c..ffd3fb56d773 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -99,6 +99,9 @@
"Glm4vVisionModel",
"Glm4vMoeVisionModel",
"EvollaSaProtPreTrainedModel",
+ "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
+ "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
+ "BltGlobalTransformer", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
"Ovis2VisionModel",
]
@@ -180,6 +183,10 @@
"CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
+ "BltPatcher", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
+ "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
+ "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
+ "BltGlobalTransformer", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
"Florence2VisionBackbone", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
]
)
@@ -402,6 +409,7 @@
"CsmDepthDecoderModel", # Building part of a bigger model
"CsmDepthDecoderForCausalLM", # Building part of a bigger model
"CsmForConditionalGeneration", # Building part of a bigger model
+ "BltPatcher", # Building part of a bigger model, tested implicitly through BltForCausalLM
"Florence2VisionBackbone", # Building part of a bigger model
]
@@ -1148,6 +1156,9 @@ def ignore_undocumented(name: str) -> bool:
# MMBT model does not really work.
if name.startswith("MMBT"):
return True
+ # BLT models are internal building blocks, tested implicitly through BltForCausalLM
+ if name.startswith("Blt"):
+ return True
if name in SHOULD_HAVE_THEIR_OWN_PAGE:
return True
return False