From 0826152372cb6f4279152060633083973da620a5 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 4 Jun 2025 11:49:49 +0000 Subject: [PATCH 001/139] blt wip --- src/demo_hf.py | 200 ++ src/transformers/models/blt_wip/blt_args.py | 250 ++ .../models/blt_wip/modeling_blt_wip.py | 2279 +++++++++++++++++ .../models/blt_wip/tokenizers/__init__.py | 1 + .../blt_wip/tokenizers/abstract_tokenizer.py | 23 + .../blt_wip/tokenizers/blt_tokenizer.py | 155 ++ .../tokenizers/sentence_piece_tokenizer.py | 62 + 7 files changed, 2970 insertions(+) create mode 100644 src/demo_hf.py create mode 100644 src/transformers/models/blt_wip/blt_args.py create mode 100644 src/transformers/models/blt_wip/modeling_blt_wip.py create mode 100644 src/transformers/models/blt_wip/tokenizers/__init__.py create mode 100644 src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py create mode 100644 src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py create mode 100644 src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py diff --git a/src/demo_hf.py b/src/demo_hf.py new file mode 100644 index 000000000000..85cc6198f6df --- /dev/null +++ b/src/demo_hf.py @@ -0,0 +1,200 @@ +import os + +import torch +import typer + +from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer, ByteLatentTransformerArgs +from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer + +from huggingface_hub import hf_hub_download +import json + +import logging +import os + +import torch + +from transformers.models.blt_wip.modeling_blt_wip import Patcher, ByteLatentTransformer +from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer + +logger = logging.getLogger() + +import os +os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" + +def get_generation_range( + prompt_tokens: list[list[int]] | None, max_gen_len: int +) -> tuple[int, int]: + batch_min_prompt_length = min([len(t) for t in prompt_tokens]) + batch_max_prompt_length = max([len(t) for t in prompt_tokens]) + return batch_min_prompt_length, batch_max_prompt_length + max_gen_len + + +def sample_top_k(probs, k): + topk_value, _ = torch.topk(probs, k) # batch_sz x topk + min_value_top_k = topk_value[:, [-1]] + probs[probs < min_value_top_k] = 0.0 + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +@torch.inference_mode() +def generate( + prompts: list[str] | None, + *, + model: ByteLatentTransformer, + tokenizer: BltTokenizer, + patcher: Patcher, + max_prompt_len: int = 256, + max_gen_len: int = 256, + use_sampling: bool = False, + temp: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + remove_prompts: bool = True, +) -> list[list[int]]: + assert ( + patcher.realtime_patching + ), "generate_nocache requires patcher.realtime_patching=True" + model.eval() + prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts] + # Truncation + prompt_tokens = [ + t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :] + for t in prompt_tokens + ] + start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len) + batch_size = len(prompt_tokens) + tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long() + + # Copy inputs to tensor for generated tokens + for i, row_tokens in enumerate(prompt_tokens): + tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long() + input_text_mask = tokens != tokenizer.pad_id + + for i, curr_pos in enumerate(range(start_pos, end_pos)): + current_tokens = tokens[:, :curr_pos] + patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True) + logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1] + + if use_sampling: + probs = torch.softmax(logits / temp, dim=-1) + if top_p > 0.0: + next_token = sample_top_p(probs, top_p) + elif top_k > 0: + next_token = sample_top_k(probs, top_k) + else: + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1) + + next_token = torch.where( + input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token + ) + tokens[:, curr_pos] = next_token + + if remove_prompts: + generated_tokens = [ + t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist() + for i, t in enumerate(tokens) + ] + else: + generated_tokens = [ + t[: len(prompt_tokens[i]) + max_gen_len].tolist() + for i, t in enumerate(tokens) + ] + return generated_tokens + + + +def main(prompt: str = "my name is", model_name: str = "blt-1b"): + # distributed_args = DistributedArgs() + # distributed_args.configure_world() + # if not torch.distributed.is_initialized(): + # setup_torch_distributed(distributed_args) + + # Set device and ensure CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required but not available") + device = torch.device("cuda") + torch.cuda.empty_cache() # Clear any existing CUDA memory + + assert model_name in ["blt-1b", "blt-7b"] + model_name = model_name.replace("-", "_") + + #HF + blt_repo = "facebook/blt-1b" + + # Get the model's default configuration and entropy model params + print("Loading model configuration...") + config_path = hf_hub_download(repo_id=blt_repo, filename="config.json") + entropy_params_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/params.json") + + with open(config_path, 'r') as f: + config = json.load(f) + with open(entropy_params_path, 'r') as f: + entropy_params = json.load(f) + + config['args']['attn_bias_type'] = 'causal' + config['args']['attn_impl'] = 'sdpa' + + model_args = ByteLatentTransformerArgs(**config["args"]) + + patcher_args = entropy_params["data"]["patcher_args"] + model_args.patch_in_forward = True + model_args.patch_size = patcher_args["patch_size"] + model_args.patching_mode = patcher_args["patching_mode"] + model_args.patching_threshold = patcher_args["threshold"] + model_args.patching_threshold_add = patcher_args["threshold_add"] + model_args.max_patch_length = patcher_args["max_patch_length"] + model_args.patching_batch_size = patcher_args["patching_batch_size"] + model_args.patching_device = patcher_args["patching_device"] + model_args.monotonicity = patcher_args["monotonicity"] + + model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device) + + # Configure model's patcher + model.patcher.realtime_patching = True + model.patcher.entropy_model_checkpoint_dir = os.path.join( + "hf-weights", "entropy_model" + ) + + tokenizer = BltTokenizer( + vocab_size_unit_1=model_args.vocab_size, + add_bos=True, + add_eos=True + ) + + prompts = [prompt] + outputs = generate( + prompts, + model=model, + tokenizer=tokenizer, + patcher=model.patcher, # Use the model's patcher + max_gen_len=100 + ) + + text_outputs = [tokenizer.decode(t) for t in outputs] + for p, t in zip(prompts, text_outputs): + print(f'Prompt: "{p}"') + print(f'Completion: "{t}"') + print() + + # Clean up + torch.cuda.empty_cache() + + +if __name__ == "__main__": + typer.run(main) diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py new file mode 100644 index 000000000000..d292d9de1f4d --- /dev/null +++ b/src/transformers/models/blt_wip/blt_args.py @@ -0,0 +1,250 @@ +from enum import Enum, auto +from typing import Any, List, Optional, Tuple, Union +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self + +EOS_ID: int = 2 + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class BaseTransformerArgs(BaseModel): + model_config = ConfigDict() + dim: int = 512 + n_layers: int = 8 + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None + + ffn_dim_multiplier: float | None = None + + multiple_of: int = 256 + + norm_eps: float = 1e-5 + + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + init_base_std: float | None = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + max_seqlen: int = 1024 + + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + +class ByteLatentTransformerArgs(BaseTransformerArgs): + # Basic model configuration + seed: int = 42 + vocab_size: int = -1 + dim: int = 512 + n_layers: int = 8 + n_heads: int = 8 + # TODO: What is the purpose of this parameter? + weight_tying: bool = False + patch_in_forward: bool = False + + # Architecture and dimensions + dim_token: int | None = None + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + + # Tokenization and patching + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + max_patch_length: int | None = None + + # Encoder/Decoder configuration + tie_local_encoder_decoder_logits: bool = False + use_local_encoder_transformer: bool = False + encoder_lm_loss: bool = False + max_encoder_seq_length: int | None = None + pad_to_max_length: bool = False + encoder_enable_byte_ngrams: bool = False + encoder_enable_byte_group_hash: bool = False + ngram_vocab_sizes: int | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder hash configurations + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + + # Model behavior and optimization + log_patch_lengths: bool = False + non_linearity: str = "swiglu" + use_rope: bool = True + recompute_fc1_out: bool = False + recompute_fc3_out: bool = False + recompute_attn: bool = True + custom_bwd: bool = False + layer_ckpt: str = "all" + + # Initialization and attention + init_use_gaussian: bool = True + init_use_depth: str = "current" + attn_bias_type: str = "causal" + alpha_depth: str = "disabled" + max_length: int = 2048 + + # Norm configuration + norm_eps: float = 1e-5 + norm_affine: bool = True + pre_norm: bool = True + norm_type: str = "rmsnorm" + + # Additional configurations + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.0 + dropout: float = 0 + output_size: int = -1 + + # Additional parameters from ModelArgs + architecture: str = "vanilla" + share_encoder_decoder_emb: bool = True + global_local_decoder_residual_layer: str | None = None + + tokenize_with_bpe_delimiter: bool = False + patching_thresholds_str: str | None = None + tie_local_encoder_decoder: bool = False + encoder_preds_low_entropy_toks: float | None = None + encoder_preds_random_toks: float | None = None + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + encoder_ngram_table_dir: str | None = None + encoder_ngram_to_size_str: str | None = None + + # Model architecture params + entropy_model_checkpoint_dir: str | None = None + entropy_model_is_ngram_model: bool = False + downsampling_by_pooling: str | None = None + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads: int | None = None + n_kv_heads_global: int | None = None + conv_kernel_size: int | None = None + local_attention_window_len: int | None = None + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + attn_to_keep: str = "all" + + # Parameter mixing + pm_size: int = 0 + + # Logging + full_logging_n_layers: int = 4 + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + return self + + +class GlobalTransformerArgs(ByteLatentTransformerArgs): + # Global encoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with global encoder specific values + self.dim = self.dim_global + self.n_layers = self.n_layers_global + self.n_heads = self.n_heads_global + self.n_kv_heads = self.n_kv_heads_global + self.local_attention_window_len = None + self.cross_attn_encoder = False + self.cross_attn_decoder = False + + +class LocalDecoderArgs(ByteLatentTransformerArgs): + # Local decoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local decoder specific values + self.dim = self.dim_local_decoder + self.n_layers = self.n_layers_local_decoder + self.n_heads = self.n_heads_local_decoder + self.cross_attn_encoder = False + self.cross_attn_init_by_pooling = False + self.attn_bias_type = "local_block_causal" + + +class LocalModelArgs(BaseTransformerArgs): + model_config = ConfigDict() + # Override defaults + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + + # Local encoder specific dimensions + dropout: float + vocab_size: int + patch_size: float + sliding_window: int | None + use_rope: bool + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + dim_token_emb: int + dim_patch_emb: int | None + + +class LMTransformerArgs(BaseTransformerArgs): + seed: int = 42 + + vocab_size: int = -1 + weight_tying: bool = False + + sliding_window: int | None = None + diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py new file mode 100644 index 000000000000..208a39b126ed --- /dev/null +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -0,0 +1,2279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum, auto +from typing import Any, List, Optional, Tuple, Union + +import torch +from huggingface_hub import PyTorchModelHubMixin +from pydantic import model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention +from typing_extensions import Self +import json +import logging + +import torch +import torch.nn +import torch.nn as nn +from pydantic import ConfigDict +from torch.nn import functional as F + +import abc + +import os +import time +from collections import defaultdict + +from pydantic import BaseModel + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + +RMSNorm = nn.RMSNorm + +logger = logging.getLogger() + +from .blt_args import ( + BaseTransformerArgs, + ByteLatentTransformerArgs, + GlobalTransformerArgs, + LocalDecoderArgs, + LocalModelArgs, + LMTransformerArgs, + +) + +if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: + flex_attention_comp = torch.compile(flex_attention) +else: + flex_attention_comp = None + + +def patch_reduce(h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, + end=self.max_seqlen, + theta=self.theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +def _reshape_for_attn_bias( + attn_bias: None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) / factor + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) / factor + out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.w1.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + nn.init.trunc_normal_( + self.w3.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + + +class TransformerBlock(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + + assert (args.head_dim is not None) or ( + args.n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = args.head_dim or args.dim // args.n_heads + self.n_heads = args.n_heads or args.dim // args.head_dim + self.n_kv_heads = args.n_kv_heads or self.n_heads + + assert args.n_heads % self.n_kv_heads == 0 + assert args.dim % args.n_heads == 0 + + self.attention = Attention( + dim=args.dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=args.rope_theta, + ) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +class SequenceModelWithOutput(abc.ABC): + @abc.abstractmethod + def get_output_seq_len(self) -> int: + pass + + +class BaseTransformer(nn.Module, SequenceModelWithOutput): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + self.dim = args.dim + self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + ) + self.eos_id = args.eos_id + + self.layers = nn.ModuleList() + for _ in range(args.n_layers): + self.layers.append(TransformerBlock(args)) + + def get_output_seq_len(self): + return self.max_seqlen + + def forward( + self, + h, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ): + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h + + def init_weights(self): + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) + + +class LMTransformer( + BaseTransformer, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + # paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + LMTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: LMTransformerArgs(**data), + ) + }, +): + def __init__(self, args: LMTransformerArgs): + super().__init__(args) + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + + assert args.vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = nn.Linear( + args.dim, + args.vocab_size, + bias=False, + ) + + if args.weight_tying: + self.output.weight = self.embeddings.tok_embeddings.weight + + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + attn_impl: str | None = None, + ): + if attn_impl is None: + attn_impl = self.attn_impl + bsz, seqlen = token_values.shape + + h = self.tok_embeddings(token_values) + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) + ) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + + logits = self.output(self.norm(h)) + if target is not None: + return cross_entropy(logits, target) + else: + return logits + + def reset_parameters(self, init_std=None): + self.norm.reset_parameters() + + def init_weights(self): + self.reset_parameters() + init_std = self.dim ** (-0.5) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + super().init_weights() + + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + + + + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class PatcherArgs(BaseModel): + patching_mode: PatchingModeEnum = PatchingModeEnum.entropy + patching_device: str = "cuda" + entropy_model_checkpoint_dir: str | None = None + realtime_patching: bool = False + threshold: float = 1.335442066192627 + threshold_add: float | None = None + max_patch_length: int | None = None + patch_size: float = 4.5 + patching_batch_size: int = 1 + device: str = "cuda" + monotonicity: bool = False + log_time: bool = False + + def build(self) -> "Patcher": + return Patcher(self) + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def to_device(entropy_model, device=None): + if device == "cuda": + rank = get_local_rank() + device = f"cuda:{rank}" + entropy_model = entropy_model.to(device) + return entropy_model, device + + +def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +class Patcher: + def __init__(self, patcher_args: PatcherArgs): + self.patcher_args = patcher_args + self.patching_mode = patcher_args.patching_mode + self.realtime_patching = patcher_args.realtime_patching + if self.realtime_patching: + assert ( + patcher_args.entropy_model_checkpoint_dir is not None + ), "Cannot require realtime patching without an entropy model checkpoint" + maybe_consolidated = os.path.join( + patcher_args.entropy_model_checkpoint_dir, + "consolidated/consolidated.pth", + ) + if os.path.exists(maybe_consolidated): + state_path = maybe_consolidated + else: + state_path = os.path.join( + patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" + ) + entropy_model, _ = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir, + state_path, + ) + # entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) + entropy_model = entropy_model.to(patcher_args.patching_device) + self.entropy_model = entropy_model + else: + self.entropy_model = None + self.threshold = patcher_args.threshold + self.threshold_add = patcher_args.threshold_add + self.max_patch_length = patcher_args.max_patch_length + self.patch_size = patcher_args.patch_size + self.patching_batch_size = patcher_args.patching_batch_size + self.device = patcher_args.device + self.monotonicity = patcher_args.monotonicity + self.log_time = patcher_args.log_time + if self.log_time: + self.log = defaultdict(float) + + def patch( + self, + tokens: torch.Tensor, + include_next_token: bool = False, + preds: torch.Tensor | None = None, + entropies: torch.Tensor | None = None, + threshold: float = None, + ) -> torch.Tensor: + """ + tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched + Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) + -> output tensor: [batch_size, max_num_patches] + each tensor is processed independently and gets right padded with zeros. + + Patching with the following modes: + 1. patching_mode = None: static patch size + 2. patching_mode = "entropy": + calculate entropy of each token, allocate patches so that the total + number of patches is the same as static patching but choose to begin + patches on tokens where the model is most uncertain (highest entropy). + + When threshold is provided, it uses the threshold to decide when to + start a new patch. + 3. patching_mode = "space": + use space like tokens to define the patches. + 4. patching_mode = "bpe": + use bpe delim tokens to define the patches. + + To correctly patch the last token, it may be necessary to include the next token in the patch + lengths calculations. This is controlled by the include_next_token argument. + """ + bs, seq_len = tokens.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + scores = None + # STATIC + if self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) + else: + raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + + # Apply any processing to patch lengths + if self.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + split_large_numbers(pl, self.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=tokens.dtype, device=tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + assert ( + torch.sum(patch_lengths) + == tokens.numel() + include_next_token * tokens.shape[0] + ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" + if self.log_time: + self.log["postprocessing_patch_lengths"] += time.time() - s + self.log["tokens"] += patch_lengths.sum().item() + return patch_lengths, scores + + + +def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): + with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["entropy_model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) + entropy_model_args = LMTransformerArgs( + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + max_seqlen=model_params["max_seqlen"], + ffn_dim_multiplier=model_params["ffn_dim_multiplier"], + vocab_size=model_params["vocab_size"], + attn_bias_type="local_block_causal", + attn_impl="xformers", + sliding_window=512, + ) + entropy_model = LMTransformer(entropy_model_args) + + entropy_model.load_state_dict( + torch.load(state_dict_path, map_location=device)["model"], strict=False + ) + entropy_model.to(device) + entropy_model = entropy_model.eval() + # no grads for the model: + for param in entropy_model.parameters(): + param.requires_grad = False + return entropy_model, entropy_model_args + + +def get_encoder_dim_token_emb(args): + if args.dim_token is not None: + dim_token_emb = args.dim_token + elif args.use_local_encoder_transformer: + dim_token_emb = args.dim_local_encoder + else: + dim_token_emb = args.dim_global // args.patch_size + return dim_token_emb + + +def get_encoder_dim_patch_emb(args): + dim_patch_emb = None + if args.cross_attn_encoder: + if args.cross_attn_init_by_pooling: + dim_patch_emb = args.dim_local_encoder + else: + dim_patch_emb = args.dim_global + return dim_patch_emb + + +def get_global_dim_patch_emb(args): + dim_token_emb = get_encoder_dim_token_emb(args) + if args.cross_attn_encoder: + dim_patch_emb = dim_token_emb * args.cross_attn_k + elif ( + args.downsampling_by_pooling is None + or not args.downsampling_by_pooling + or len(args.downsampling_by_pooling) == 0 + ): + dim_patch_emb = dim_token_emb * args.patch_size + else: + dim_patch_emb = dim_token_emb * sum( + [ + pooling in args.downsampling_by_pooling + for pooling in ["avg", "min", "max"] + ] + ) + return dim_patch_emb + + +def get_decoder_dim_token_emb(args): + if args.share_encoder_decoder_emb: + dim_token_emb = get_encoder_dim_token_emb(args) + elif args.dim_token is not None: + dim_token_emb = args.dim_token + else: + dim_token_emb = args.dim_local_decoder + return dim_token_emb + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): + first_patch_length = patch_lengths[0, 0] + assert torch.all( + first_patch_length == patch_lengths[:, 0] + ), "first patch should always be the same size (1 for dynamic, patch_size for static)." + assert ( + first_patch_length - nb_boe == 1 + ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" + # Remove first patch from patch_ids for local decoder inputs and shift the last patch. + # decoder_patch_lengths = patch_lengths[:, 1:].clone() + # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) + decoder_patch_lengths = patch_lengths[:, 1:] + assert ( + decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] + == patch_lengths.sum() + ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" + assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" + decoder_patch_ids = patch_ids_from_lengths( + patch_lengths=decoder_patch_lengths, seq_len=seq_len + ) + return decoder_patch_ids + + +primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, +] + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + # x_numpy = x.numpy() + # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) + # for i in range(bs): + # for j in range(seq_len): + # start = max(j, j-group_size+1) + # end = j+1 + # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) + + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. + + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +def patch_ids_from_lengths(patch_lengths, seq_len): + bs, num_patches = patch_lengths.shape + # Create a tensor of cumulative sums of the patch lengths + cum_d = torch.cat( + [ + torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1), + ], + dim=-1, + ) + patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( + dim=-2 + ) - 1 + assert not ( + torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 + ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" + return patch_ids + + + +def create_global_transformer(args: ByteLatentTransformerArgs): + global_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_global, + n_layers=args.n_layers_global, + n_heads=args.n_heads_global, + n_kv_heads=args.n_kv_heads_global, + local_attention_window_len=None, + dim_token_emb=get_global_dim_patch_emb(args), + dim_patch_emb=None, + cross_attn_encoder=False, + cross_attn_decoder=False, + ), + ) + + return GlobalTransformer(global_args) + + +class LocalModelBase(nn.Module): + def __init__(self, args: LocalModelArgs): + super().__init__() + + self.dim = args.dim + self.dropout = args.dropout + self.vocab_size = args.vocab_size + self.patch_size = args.patch_size + self.dim_patch_emb = args.dim_patch_emb + + self.attn_impl = args.attn_impl + self.sliding_window = args.sliding_window + self.use_rope = args.use_rope + self.init_std_factor = args.init_std_factor + self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) + self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id + + self.boe_id = BOE_ID + + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + else: + self.rope = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + self.token_embedding_projection = ( + nn.Linear(args.dim_token_emb, args.dim, bias=False) + if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(args) + + def _should_create_patch_projection(self, args: LocalModelArgs): + dimension_mismatch = ( + getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, args): + if not self._should_create_patch_projection(args): + return None + + output_dim = args.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=args.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + if hasattr(self, "norm"): + self.norm.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + if hasattr(self, "tok_embeddings"): + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(None, factor) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + patch_emb_std = self.dim_patch_emb ** (-0.5) + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=patch_emb_std, + a=-3 * patch_emb_std, + b=3 * patch_emb_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(None, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, args: LocalModelArgs): + super().__init__(args) + + self.apply_transformer = args.use_local_encoder_transformer + self.downsampling_by_pooling = args.downsampling_by_pooling + self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." + return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + patch_embeds = self.apply_cross_attention( + h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask + ) + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def apply_cross_attention( + self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + # patch_embeds = downsample( + # h, + # num_patches, + # patch_ids=patch_ids, + # downsampling_by_pooling=self.downsampling_by_pooling, + # patch_size=self.patch_size, + # ) + patch_embeds = patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + return patch_embeds + patch_embeds_cross + + +class LocalDecoder(LocalModelBase): + def __init__(self, args: LocalModelArgs): + super().__init__(args) + + # Model configuration flags + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + args.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class CrossAttention(nn.Module): + """ + CrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. + """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std or (self.dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(BaseTransformer): + def __init__(self, args: BaseTransformerArgs): + super().__init__(args) + self.dropout = args.dropout + self.eos_id = args.eos_id + self.dim_token_emb = args.dim_token_emb + + self.token_embedding_projection = None + if args.dim_token_emb is not None and args.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + args.dim_token_emb, + args.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ + Similar to BaseTransformer.forward, but with an additional embeds argument + and projection to the token space. + """ + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + return h, cache + + def init_weights(self): + super().init_weights() + std = self.dim_token_emb ** (-0.5) + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + +class EmbeddingType(Enum): + HASH_TOK = auto() + NGRAM = auto() + + +def init_embeddings( + args, + embedding_type: EmbeddingType, + local_encoder_dim: int, + encoder_hash_byte_group_size: list = None, +): + if ( + embedding_type == EmbeddingType.HASH_TOK + and args.encoder_hash_byte_group_size is None + ): + return None + if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: + return None + + embeddings = [] + + if embedding_type == EmbeddingType.HASH_TOK: + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + for _ in range(args.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + elif embedding_type == EmbeddingType.NGRAM: + encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) + emb_dim = local_encoder_dim + OFFSET = 4 # This should be passed as parameter if it's variable + for ngram_vocab_size in encoder_ngram_to_size.values(): + embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) + + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. + + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class ByteLatentTransformer( + nn.Module, + SequenceModelWithOutput, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + # paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + ByteLatentTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: ByteLatentTransformerArgs(**data), + ) + }, +): + """ + The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. + """ + + def __init__(self, args: ByteLatentTransformerArgs): + super().__init__() + + # General configuration + self.weight_tying = args.weight_tying + self.patch_size = args.patch_size + self.patching_mode = args.patching_mode + self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( + BOE_ID, + BOS_ID, + PAD_ID, + EOS_ID, + ) + self.downsampling_by_pooling = args.downsampling_by_pooling + self.patching_threshold = args.patching_threshold + self.dim = args.dim + self.init_base_std = args.init_base_std + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + + # Cross attention configuration + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_k = args.cross_attn_k + self.cross_attn_window_encoder = args.cross_attn_window_encoder + self.cross_attn_window_decoder = args.cross_attn_window_decoder + self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention + + # Encoder hash configuration + self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = ( + args.encoder_hash_byte_group_nb_functions + ) + + # ByteLatent modules + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + self.local_encoder = LocalEncoder(local_encoder_args) + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + + self.global_transformer = create_global_transformer(args) + self.local_decoder = LocalDecoder(local_decoder_args) + self.encoder_hash_tok_embedding = init_embeddings( + args, + EmbeddingType.HASH_TOK, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + ) + self.encoder_ngram_embedding = init_embeddings( + args, + EmbeddingType.NGRAM, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=None, + ) + + # Encoder ngram embedding tables + self.encoder_ngram_embedding = None + if args.encoder_enable_byte_ngrams: + self.encoder_ngram_embedding = nn.ModuleList() + assert args.ngram_vocab_sizes is not None + self.encoder_ngram_to_size = parse_ngram_to_size( + args.encoder_ngram_to_size_str + ) + ngram_emb_dim = self.local_encoder.dim + for ngram_vocab_size in self.encoder_ngram_to_size.values(): + self.encoder_ngram_embedding.append( + nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) + ) + + # Output layer + assert args.vocab_size > 0, "vocab_size must be greater than 0" + + # Patcher module + if args.patch_in_forward: + self.patcher = Patcher( + PatcherArgs( + patch_size=args.patch_size, + patching_mode=args.patching_mode, + patching_threshold=args.patching_threshold, + patching_threshold_add=args.patching_threshold_add, + monotonicity=args.monotonicity, + max_patch_length=args.max_patch_length, + ) + ) + + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + + def get_output_seq_len(self): + return self.max_seqlen + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ngram_ids: Optional[torch.Tensor] = None, + ): + # Ensure ngram_ids is either a tensor or None + assert ( + isinstance(ngram_ids, torch.Tensor) or ngram_ids is None + ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.patch_size, + boe_id=self.boe_id, + ) + + # Patching + if patch_lengths is None: + assert ( + getattr(self, "patcher", None) is not None + ), "Patcher not defined and no patch_lengths passed." + patch_lengths, tok_scores = self.patcher.patch( + local_encoder_tokens, + include_next_token=True, + threshold=self.patcher.threshold, + ) + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + ) + + # N-gram table embeddings + if self.encoder_ngram_embedding is not None: + assert ngram_ids is not None, "ngram_ids must be provided" + if local_encoder_embeds is None: + local_encoder_embeds = self.local_encoder.tok_embeddings( + local_encoder_tokens + ) + assert len(ngram_ids) == len( + self.encoder_ngram_embedding + ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" + for i in range(ngram_ids.shape[0]): + ngram_embedding = self.encoder_ngram_embedding[i] + ngram_embeds = ngram_embedding(ngram_ids[i]) + assert ( + local_encoder_embeds.shape == ngram_embeds.shape + ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" + local_encoder_embeds = local_encoder_embeds + ngram_embeds + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.eos_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def init_weights(self): + self.local_encoder.init_weights() + self.global_transformer.init_weights() + self.local_decoder.init_weights() + + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, + ) \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenizers/__init__.py b/src/transformers/models/blt_wip/tokenizers/__init__.py new file mode 100644 index 000000000000..71ca4b12c770 --- /dev/null +++ b/src/transformers/models/blt_wip/tokenizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py new file mode 100644 index 000000000000..f827302aaa4e --- /dev/null +++ b/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import abc + + +class Tokenizer(abc.ABC): + @abc.abstractmethod + def encode(self, text: str, add_bos: bool, add_eos: bool): + pass + + @abc.abstractmethod + def decode(self, tokens: list[int]): + pass + + @abc.abstractmethod + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + """Return the offsets of the tokens in the original text. Only used for evaluation.""" + pass + + @abc.abstractmethod + def get_vocab_size(self) -> int: + pass diff --git a/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py new file mode 100644 index 000000000000..6d874d910c11 --- /dev/null +++ b/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import re + +from .abstract_tokenizer import Tokenizer +from .sentence_piece_tokenizer import SentencePieceTokenizer + + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + + +def convert_to_bytes(s): + # check if the output is a bytes like object of the format <0x00> + if re.match(r"<0x[0-9a-fA-F]+>", s): + return bytes.fromhex(s[3:-1]) + else: + return bytes(s, "utf-8", errors="ignore") + + +def text2bytes_bpe_delims( + text: str, + *, + bpe_tokenizer, + bpe_id: int, + offsetting_special_char: int, + add_bos: bool, + add_eos: bool, +): + cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) + # merge the leading space tokens + leading_space_tokens = [] + other_bpe_tokens = [] + leading = True + for token in cur_bpe: + bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) + if leading and all(c == "▁" for c in bpe_str): + leading_space_tokens.append(bpe_str) + else: + leading = False + other_bpe_tokens.append(bpe_str) + cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens + + # Remove the '▁' characters + bpe_strs = [] + for i, bpe_str in enumerate(cur_bpe_strs): + if ( + len(bpe_strs) <= 1 + and all([c == " " for s in bpe_strs for c in s]) + and not all(c == "▁" for c in bpe_str) + ): + # Remove leading space for first non space token. + bpe_str = bpe_str.replace("▁", "") + elif i == 0 and all(c == "▁" for c in bpe_str): + bpe_str = " " * (len(text) - len(text.lstrip(" "))) + else: + bpe_str = bpe_str.replace("▁", " ") + if len(bpe_str) > 0: + bpe_strs.append(bpe_str) + ex_seq = [] + # Convert bpe tokens to bytes + for s in bpe_strs: + byte_chunk = convert_to_bytes(s) + proc_chunk = [int(unit) for unit in byte_chunk] + ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) + + return ex_seq + + +class BltTokenizer(Tokenizer): + def __init__( + self, + *, + vocab_size_unit_1: int = BYTE_UNITS, + bpe_delim: bool = False, + bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + add_bos: bool = True, + add_eos: bool = True, + ): + self.add_bos = add_bos + self.add_eos = add_eos + self.vocab_size_unit_1 = vocab_size_unit_1 + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.bpe_tokenizer_path = bpe_tokenizer_path + if bpe_delim: + self.bpe_tokenizer = SentencePieceTokenizer( + model_path=self.bpe_tokenizer_path + ) + else: + self.bpe_tokenizer = None + self.bpe_delim = bpe_delim + self.offsetting_special_char = OFFSET + self.vocab_size_unit_1 = vocab_size_unit_1 + self.n_words = vocab_size_unit_1 + self.offsetting_special_char + + def get_vocab_size(self) -> int: + return self.n_words + + def encode( + self, text: str, add_bos: bool | None = None, add_eos: bool | None = None + ): + if add_bos is None: + add_bos = self.add_bos + if add_eos is None: + add_eos = self.add_eos + + if self.bpe_delim: + tokens = text2bytes_bpe_delims( + text, + bpe_tokenizer=self.bpe_tokenizer, + bpe_id=self.bpe_id, + offsetting_special_char=self.offsetting_special_char, + add_bos=False, + add_eos=False, + ) + else: + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break + return bytes( + [ + tok - self.offsetting_special_char + for tok in tokens + if tok - self.offsetting_special_char >= 0 + ] + ).decode("utf-8", errors="ignore") + + def get_token_offsets(self, text: str, tokens: list[int] | None = None): + # TODO: Figure out what this does + raise NotImplementedError() diff --git a/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py new file mode 100644 index 000000000000..f789ae77d7fa --- /dev/null +++ b/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os + +try: + from sentencepiece import SentencePieceProcessor + + has_sp = True +except ImportError: + has_sp = False + +from .abstract_tokenizer import Tokenizer + +logger = logging.getLogger(__name__) + + +class SentencePieceTokenizer(Tokenizer): + def __init__( + self, model_path: str, add_bos: bool = True, add_eos: bool = True + ) -> None: + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + self.add_bos = add_bos + self.add_eos = add_eos + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def get_vocab_size(self) -> int: + return self.n_words + + def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None): + if add_bos is None: + add_bos = self.add_bos + + if add_eos is None: + add_eos = self.add_eos + assert type(s) is str + tokens = ( + [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos + ) + return tokens + + def decode(self, tokens: list[int]): + return self.sp_model.decode(tokens) + + def get_token_offsets( + self, text: str, tokens: list[int] | None = None + ) -> tuple[list[str], list[int]]: + pieces = self.sp_model.encode_as_immutable_proto(text).pieces + substrs = [p.surface for p in pieces] + offsets = [p.begin for p in pieces] + return substrs, offsets From 62019471c7adeae4a4f39ddcf5dfc57b05940b63 Mon Sep 17 00:00:00 2001 From: itazap Date: Thu, 5 Jun 2025 10:31:13 +0200 Subject: [PATCH 002/139] cpu version --- src/demo_hf.py | 38 +- src/transformers/models/blt_wip/blt_args.py | 2 + .../models/blt_wip/modeling_blt_wip.py | 390 +++++++++++++----- 3 files changed, 296 insertions(+), 134 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 85cc6198f6df..f123d0241bc4 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -1,11 +1,3 @@ -import os - -import torch -import typer - -from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer, ByteLatentTransformerArgs -from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer - from huggingface_hub import hf_hub_download import json @@ -14,7 +6,7 @@ import torch -from transformers.models.blt_wip.modeling_blt_wip import Patcher, ByteLatentTransformer +from transformers.models.blt_wip.modeling_blt_wip import Patcher, ByteLatentTransformer, ByteLatentTransformerArgs from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer logger = logging.getLogger() @@ -77,7 +69,8 @@ def generate( ] start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len) batch_size = len(prompt_tokens) - tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokens = torch.full((batch_size, end_pos), tokenizer.pad_id, dtype=torch.long, device=device) # Copy inputs to tensor for generated tokens for i, row_tokens in enumerate(prompt_tokens): @@ -120,19 +113,7 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): - # distributed_args = DistributedArgs() - # distributed_args.configure_world() - # if not torch.distributed.is_initialized(): - # setup_torch_distributed(distributed_args) - - # Set device and ensure CUDA is available - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required but not available") - device = torch.device("cuda") - torch.cuda.empty_cache() # Clear any existing CUDA memory - - assert model_name in ["blt-1b", "blt-7b"] - model_name = model_name.replace("-", "_") + device = "cpu" #HF blt_repo = "facebook/blt-1b" @@ -155,14 +136,14 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): patcher_args = entropy_params["data"]["patcher_args"] model_args.patch_in_forward = True model_args.patch_size = patcher_args["patch_size"] - model_args.patching_mode = patcher_args["patching_mode"] + model_args.patching_mode = patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer model_args.patching_threshold = patcher_args["threshold"] model_args.patching_threshold_add = patcher_args["threshold_add"] model_args.max_patch_length = patcher_args["max_patch_length"] model_args.patching_batch_size = patcher_args["patching_batch_size"] model_args.patching_device = patcher_args["patching_device"] model_args.monotonicity = patcher_args["monotonicity"] - + model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device) # Configure model's patcher @@ -192,9 +173,6 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): print(f'Completion: "{t}"') print() - # Clean up - torch.cuda.empty_cache() - - if __name__ == "__main__": - typer.run(main) + main() + diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py index d292d9de1f4d..b43f33574634 100644 --- a/src/transformers/models/blt_wip/blt_args.py +++ b/src/transformers/models/blt_wip/blt_args.py @@ -51,6 +51,8 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): weight_tying: bool = False patch_in_forward: bool = False + realtime_patching: bool = True + # Architecture and dimensions dim_token: int | None = None dim_global: int = 512 diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 208a39b126ed..4ca845fb23c2 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -50,66 +50,13 @@ ) -if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: - flex_attention_comp = torch.compile(flex_attention) -else: - flex_attention_comp = None - - -def patch_reduce(h, max_num_patches, reduction, patch_ids): - """ - Reduce variable length patches to single embedding per patch - Note: this works with variable number of patches for different sequences in the batch - It handles variable length patches by assuming that patch_lengths will be 0 for any - extra patches on the *right*. Since there can be a variable number of patches - this function also return the number of patches for each sequence in the batch. - Any embeddings on the right that are not allocated to a patch - (i.e. if the sum(patch_lengths[i]) < seq_len for any i) - will be sent to a dummy patch, which is trimmed before returning. - """ - bs, seq_len, emb_dim = h.shape - - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - - reduced_embs = torch.zeros( - (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device - ) - reduced_embs = reduced_embs.scatter_reduce( - src=h, - dim=1, - index=patch_ids, - reduce=reduction, - include_self=False, - ) - reduced_embs = reduced_embs[:, :max_num_patches, :] - - return reduced_embs +flex_attention_comp = flex_attention def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): - """ - 0 0 0 1 0 0 0 1 0 0 0 - 0 1 0 0 0 1 0 0 0 0 0 - -> 4 4 3 2 4 5 - """ - mask = batch == eos_id - mask[:, -1] = True # virtual eos at the end of each row - - # 0 0 0 1 0 0 0 1 0 0 X - # 0 1 0 0 0 1 0 0 0 0 X - row, col = torch.where(mask) - - # row = 0, 0, 0, 1, 1, 1 - # col = 3, 7, 10, 1, 5, 10 - seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] - # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) - return [int(col[0].item() + 1)] + seqlens.tolist() - - def create_causal_mask( seqlen, attn_impl: str, @@ -723,10 +670,6 @@ def init_weights(self): ) - - - - class PatchingModeEnum(str, Enum): entropy = "entropy" bpe = "bpe" @@ -738,7 +681,7 @@ class PatchingModeEnum(str, Enum): class PatcherArgs(BaseModel): patching_mode: PatchingModeEnum = PatchingModeEnum.entropy - patching_device: str = "cuda" + patching_device: str = "cpu" entropy_model_checkpoint_dir: str | None = None realtime_patching: bool = False threshold: float = 1.335442066192627 @@ -746,7 +689,7 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - device: str = "cuda" + device: str = "cpu" monotonicity: bool = False log_time: bool = False @@ -771,13 +714,179 @@ def check_non_zero_after_zero(tensor): def to_device(entropy_model, device=None): - if device == "cuda": + if device == "cpu": rank = get_local_rank() device = f"cuda:{rank}" entropy_model = entropy_model.to(device) return entropy_model, device + +def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + +def calculate_entropies( + tokens: torch.tensor, + entropy_model, + patching_batch_size, + device: str | None = None, + enable_grad: bool = False, +): + """ + tokens: 2D tensor of shape [batch_size, seq_len] + Return 2D tensor of shape [batch_size, seq_len] with entropies for each token. + + Splits the tokens into chunks of size max_length and calculates entropies for each chunk. + Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument. + """ + + grad_context = nullcontext() if enable_grad else torch.no_grad() + + with grad_context: + entropies = [] + preds = [] + max_length = getattr(entropy_model, "max_length", 8192) + batch_numel = max_length * patching_batch_size + splits = torch.split(tokens.flatten(), batch_numel) + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + # assert torch.all(split >= 0) and torch.all(split < 260) + pred = entropy_model(split) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + concat_preds = torch.cat(preds, dim=0) + concat_preds = concat_preds.reshape(tokens.shape[0], -1) + return concat_entropies, concat_preds + + +def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + +def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + +def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, +): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. + entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + # Assumes that there is at least one token going over the threshold + if monotonicity: + patch_start_mask = patch_start_mask_from_entropy_with_monotonicity( + entropies, threshold + ) + elif threshold_add is not None and threshold is not None: + patch_start_mask = patch_start_mask_global_and_monotonicity( + entropies, threshold, threshold_add + ) + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + def split_large_numbers(lst, m): new_lst = [] for i in lst: @@ -791,12 +900,15 @@ def split_large_numbers(lst, m): assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" return new_lst - class Patcher: def __init__(self, patcher_args: PatcherArgs): self.patcher_args = patcher_args self.patching_mode = patcher_args.patching_mode self.realtime_patching = patcher_args.realtime_patching + # self.realtime_patching = True + # entropy_model_checkpoint_dir = os.path.join( + # "hf-weights", "entropy_model" + # ) if self.realtime_patching: assert ( patcher_args.entropy_model_checkpoint_dir is not None @@ -815,8 +927,7 @@ def __init__(self, patcher_args: PatcherArgs): patcher_args.entropy_model_checkpoint_dir, state_path, ) - # entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) - entropy_model = entropy_model.to(patcher_args.patching_device) + entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) self.entropy_model = entropy_model else: self.entropy_model = None @@ -870,6 +981,41 @@ def patch( patch_lengths = torch.ones( (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device ) + elif self.patching_mode == PatchingModeEnum.entropy: + if self.log_time: + s = time.time() + if entropies is not None: + scores = entropies.to(dtype=torch.float32) + elif preds is not None: + scores = entropy(preds) + else: + start_entropies = time.time() + scores, _ = calculate_entropies( + tokens, + self.entropy_model, + self.patching_batch_size, + self.device, + ) + if self.log_time: + self.log["calculate_entropies"] += time.time() - s + s = time.time() + patch_start_ids = find_entropy_patch_start_ids( + scores, + self.patch_size, + include_next_token=include_next_token, + threshold=threshold if threshold is not None else self.threshold, + threshold_add=self.threshold_add, + monotonicity=self.monotonicity, + ) + if self.log_time: + self.log["find_entropy_patch_start_ids"] += time.time() - s + s = time.time() + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + if self.log_time: + self.log["patch_lengths_from_start_ids"] += time.time() - s + s = time.time() else: raise NotImplementedError(f"self.patching_mode {self.patching_mode}") @@ -909,6 +1055,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: reloaded = json.loads(fr.read()) + + # with open("/Users/itazaporozhets/.cache/huggingface/hub/models--facebook--blt-1b/snapshots/8134b32f0b1d25d1248c30e8c7bdfd442d3bb380/entropy_model/params.json") as fr: torch.set_default_dtype(torch.bfloat16) model_params = reloaded["entropy_model"] logger.warning( @@ -927,6 +1075,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp ) entropy_model = LMTransformer(entropy_model_args) + # state_dict_path = /Users/itazaporozhets/.cache/huggingface/hub/models--facebook--blt-1b/snapshots/8134b32f0b1d25d1248c30e8c7bdfd442d3bb380/model.safetensors" entropy_model.load_state_dict( torch.load(state_dict_path, map_location=device)["model"], strict=False ) @@ -1033,22 +1182,20 @@ def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): return decoder_patch_ids -primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, -] - - def rolling_polynomial_hash(t, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) return torch.sum(t * prime_powers, dim=-1) @@ -1147,6 +1294,7 @@ def cross_attn_mask( q_len, kv_len, ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + block_mask = None if block_mask: def patch_mask(b, h, q_idx, kv_idx): @@ -1286,26 +1434,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): return patch_ids - -def create_global_transformer(args: ByteLatentTransformerArgs): - global_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_global, - n_layers=args.n_layers_global, - n_heads=args.n_heads_global, - n_kv_heads=args.n_kv_heads_global, - local_attention_window_len=None, - dim_token_emb=get_global_dim_patch_emb(args), - dim_patch_emb=None, - cross_attn_encoder=False, - cross_attn_decoder=False, - ), - ) - - return GlobalTransformer(global_args) - - class LocalModelBase(nn.Module): def __init__(self, args: LocalModelArgs): super().__init__() @@ -1543,7 +1671,7 @@ def apply_cross_attention( # downsampling_by_pooling=self.downsampling_by_pooling, # patch_size=self.patch_size, # ) - patch_embeds = patch_reduce(h, num_patches, "amax", patch_ids) + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( @@ -1558,6 +1686,35 @@ def apply_cross_attention( ) return patch_embeds + patch_embeds_cross + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + class LocalDecoder(LocalModelBase): def __init__(self, args: LocalModelArgs): @@ -1723,9 +1880,18 @@ def forward( xk = repeat_kv(xk, self.heads_per_group, dim=2) xv = repeat_kv(xv, self.heads_per_group, dim=2) - assert mask is None or isinstance(mask, BlockMask) + # assert mask is None or isinstance(mask, BlockMask) xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) - output = flex_attention_comp(xq, xk, xv, block_mask=mask) + #output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D output = self.wo(output.reshape(output_shape)) @@ -2057,7 +2223,23 @@ def __init__(self, args: ByteLatentTransformerArgs): eos_id=args.eos_id, ) - self.global_transformer = create_global_transformer(args) + global_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_global, + n_layers=args.n_layers_global, + n_heads=args.n_heads_global, + n_kv_heads=args.n_kv_heads_global, + local_attention_window_len=None, + dim_token_emb=get_global_dim_patch_emb(args), + dim_patch_emb=None, + cross_attn_encoder=False, + cross_attn_decoder=False, + ), + ) + + self.global_transformer = GlobalTransformer(global_args) + self.local_decoder = LocalDecoder(local_decoder_args) self.encoder_hash_tok_embedding = init_embeddings( args, @@ -2081,7 +2263,7 @@ def __init__(self, args: ByteLatentTransformerArgs): args.encoder_ngram_to_size_str ) ngram_emb_dim = self.local_encoder.dim - for ngram_vocab_size in self.encoder_ngram_to_size.values(): + for ngram_vocab_size in self.encoderngram_to_size.values(): self.encoder_ngram_embedding.append( nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) ) From 58c4a4e75d4951abe89a46a79a74731a099f0592 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 6 Jun 2025 11:38:54 +0200 Subject: [PATCH 003/139] cpu friendly with full entropy model (real time patching) --- src/demo_hf.py | 6 ++- src/transformers/models/blt_wip/blt_args.py | 2 +- .../models/blt_wip/modeling_blt_wip.py | 38 ++++--------------- 3 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index f123d0241bc4..98cd01a22db2 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -122,6 +122,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): print("Loading model configuration...") config_path = hf_hub_download(repo_id=blt_repo, filename="config.json") entropy_params_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/params.json") + entropy_checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/consolidated.pth") + entropy_dir = os.path.dirname(entropy_checkpoint_path) with open(config_path, 'r') as f: config = json.load(f) @@ -136,13 +138,15 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): patcher_args = entropy_params["data"]["patcher_args"] model_args.patch_in_forward = True model_args.patch_size = patcher_args["patch_size"] - model_args.patching_mode = patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer + model_args.patching_mode = "entropy" #patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer model_args.patching_threshold = patcher_args["threshold"] model_args.patching_threshold_add = patcher_args["threshold_add"] model_args.max_patch_length = patcher_args["max_patch_length"] model_args.patching_batch_size = patcher_args["patching_batch_size"] model_args.patching_device = patcher_args["patching_device"] model_args.monotonicity = patcher_args["monotonicity"] + model_args.entropy_model_checkpoint_dir = entropy_dir #original args on the hub don't set this + model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device) diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py index b43f33574634..600e852d1980 100644 --- a/src/transformers/models/blt_wip/blt_args.py +++ b/src/transformers/models/blt_wip/blt_args.py @@ -217,7 +217,7 @@ def __post_init__(self): class LocalModelArgs(BaseTransformerArgs): model_config = ConfigDict() # Override defaults - attn_impl: str | None = "xformers" + attn_impl: str | None = "sdpa" # originally xformers attn_bias_type: str | None = "local_block_causal" # Local encoder specific dimensions diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 4ca845fb23c2..661c3f304b46 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -624,7 +624,7 @@ def forward( bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) - + # attn_impl = "sdpa" mask = ( mask if mask is not None @@ -712,16 +712,6 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() - -def to_device(entropy_model, device=None): - if device == "cpu": - rank = get_local_rank() - device = f"cuda:{rank}" - entropy_model = entropy_model.to(device) - return entropy_model, device - - - def entropy(scores): """ scores: [bs, seq_len, vocab] @@ -866,17 +856,7 @@ def find_entropy_patch_start_ids( patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices patch_start_ids = patch_start_ids.sort(dim=1).values else: - # Assumes that there is at least one token going over the threshold - if monotonicity: - patch_start_mask = patch_start_mask_from_entropy_with_monotonicity( - entropies, threshold - ) - elif threshold_add is not None and threshold is not None: - patch_start_mask = patch_start_mask_global_and_monotonicity( - entropies, threshold, threshold_add - ) - else: - patch_start_mask = entropies > threshold + patch_start_mask = entropies > threshold if not include_next_token: patch_start_mask = patch_start_mask[:, :-1] # patch_start_mask[1:] |= tokens[:-1] < OFFSET @@ -905,10 +885,7 @@ def __init__(self, patcher_args: PatcherArgs): self.patcher_args = patcher_args self.patching_mode = patcher_args.patching_mode self.realtime_patching = patcher_args.realtime_patching - # self.realtime_patching = True - # entropy_model_checkpoint_dir = os.path.join( - # "hf-weights", "entropy_model" - # ) + self.realtime_patching = True if self.realtime_patching: assert ( patcher_args.entropy_model_checkpoint_dir is not None @@ -927,7 +904,7 @@ def __init__(self, patcher_args: PatcherArgs): patcher_args.entropy_model_checkpoint_dir, state_path, ) - entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) + entropy_model = entropy_model.to(patcher_args.patching_device) self.entropy_model = entropy_model else: self.entropy_model = None @@ -1055,8 +1032,6 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: reloaded = json.loads(fr.read()) - - # with open("/Users/itazaporozhets/.cache/huggingface/hub/models--facebook--blt-1b/snapshots/8134b32f0b1d25d1248c30e8c7bdfd442d3bb380/entropy_model/params.json") as fr: torch.set_default_dtype(torch.bfloat16) model_params = reloaded["entropy_model"] logger.warning( @@ -1070,12 +1045,11 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp ffn_dim_multiplier=model_params["ffn_dim_multiplier"], vocab_size=model_params["vocab_size"], attn_bias_type="local_block_causal", - attn_impl="xformers", + attn_impl="sdpa", #originally xformers sliding_window=512, ) entropy_model = LMTransformer(entropy_model_args) - # state_dict_path = /Users/itazaporozhets/.cache/huggingface/hub/models--facebook--blt-1b/snapshots/8134b32f0b1d25d1248c30e8c7bdfd442d3bb380/model.safetensors" entropy_model.load_state_dict( torch.load(state_dict_path, map_location=device)["model"], strict=False ) @@ -1885,6 +1859,7 @@ def forward( #output = flex_attention_comp(xq, xk, xv, block_mask=mask) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype) output = F.scaled_dot_product_attention( xq, xk, @@ -2281,6 +2256,7 @@ def __init__(self, args: ByteLatentTransformerArgs): patching_threshold_add=args.patching_threshold_add, monotonicity=args.monotonicity, max_patch_length=args.max_patch_length, + entropy_model_checkpoint_dir=args.entropy_model_checkpoint_dir ) ) From 1d00859a7de1cf00d6124c68af1083595388703c Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 6 Jun 2025 15:51:19 +0000 Subject: [PATCH 004/139] adding config file instead of args file --- src/demo_hf.py | 55 +- src/transformers/models/blt_wip/blt_args.py | 273 +++--- .../models/blt_wip/configuration_blt.py | 563 ++++++++++++ .../models/blt_wip/modeling_blt_wip.py | 799 ++++++++---------- 4 files changed, 1043 insertions(+), 647 deletions(-) create mode 100644 src/transformers/models/blt_wip/configuration_blt.py diff --git a/src/demo_hf.py b/src/demo_hf.py index 98cd01a22db2..c93599a0e09a 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -6,7 +6,8 @@ import torch -from transformers.models.blt_wip.modeling_blt_wip import Patcher, ByteLatentTransformer, ByteLatentTransformerArgs +from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer +from transformers.models.blt_wip.configuration_blt import BLTConfig from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer logger = logging.getLogger() @@ -48,7 +49,6 @@ def generate( *, model: ByteLatentTransformer, tokenizer: BltTokenizer, - patcher: Patcher, max_prompt_len: int = 256, max_gen_len: int = 256, use_sampling: bool = False, @@ -57,9 +57,11 @@ def generate( top_p: float = 0.0, remove_prompts: bool = True, ) -> list[list[int]]: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + assert ( - patcher.realtime_patching - ), "generate_nocache requires patcher.realtime_patching=True" + model.patch_in_forward + ), "generate requires model.patch_in_forward=True" model.eval() prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts] # Truncation @@ -69,7 +71,6 @@ def generate( ] start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len) batch_size = len(prompt_tokens) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokens = torch.full((batch_size, end_pos), tokenizer.pad_id, dtype=torch.long, device=device) # Copy inputs to tensor for generated tokens @@ -79,7 +80,7 @@ def generate( for i, curr_pos in enumerate(range(start_pos, end_pos)): current_tokens = tokens[:, :curr_pos] - patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True) + patch_lengths, _ = model.patch(current_tokens, include_next_token=True) logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1] if use_sampling: @@ -113,7 +114,7 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): - device = "cpu" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #HF blt_repo = "facebook/blt-1b" @@ -133,31 +134,26 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): config['args']['attn_bias_type'] = 'causal' config['args']['attn_impl'] = 'sdpa' - model_args = ByteLatentTransformerArgs(**config["args"]) + model_config = BLTConfig(**config["args"]) patcher_args = entropy_params["data"]["patcher_args"] - model_args.patch_in_forward = True - model_args.patch_size = patcher_args["patch_size"] - model_args.patching_mode = "entropy" #patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer - model_args.patching_threshold = patcher_args["threshold"] - model_args.patching_threshold_add = patcher_args["threshold_add"] - model_args.max_patch_length = patcher_args["max_patch_length"] - model_args.patching_batch_size = patcher_args["patching_batch_size"] - model_args.patching_device = patcher_args["patching_device"] - model_args.monotonicity = patcher_args["monotonicity"] - model_args.entropy_model_checkpoint_dir = entropy_dir #original args on the hub don't set this - - - model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device) - - # Configure model's patcher - model.patcher.realtime_patching = True - model.patcher.entropy_model_checkpoint_dir = os.path.join( - "hf-weights", "entropy_model" - ) + model_config.patch_in_forward = True + model_config.realtime_patching = True # Enable realtime patching + model_config.patch_size = patcher_args["patch_size"] + model_config.patching_mode = "entropy" #patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer + model_config.patching_threshold = patcher_args["threshold"] + model_config.patching_threshold_add = patcher_args["threshold_add"] + model_config.max_patch_length = patcher_args["max_patch_length"] + model_config.patching_batch_size = patcher_args["patching_batch_size"] + model_config.patching_device = patcher_args["patching_device"] + model_config.monotonicity = patcher_args["monotonicity"] + model_config.entropy_model_checkpoint_dir = entropy_dir #original config on the hub don't set this + + + model = ByteLatentTransformer.from_pretrained(blt_repo, config=model_config).to(device) tokenizer = BltTokenizer( - vocab_size_unit_1=model_args.vocab_size, + vocab_size_unit_1=model_config.vocab_size, add_bos=True, add_eos=True ) @@ -166,8 +162,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): outputs = generate( prompts, model=model, - tokenizer=tokenizer, - patcher=model.patcher, # Use the model's patcher + tokenizer=tokenizer, max_gen_len=100 ) diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py index 600e852d1980..78fc3482c362 100644 --- a/src/transformers/models/blt_wip/blt_args.py +++ b/src/transformers/models/blt_wip/blt_args.py @@ -1,5 +1,5 @@ -from enum import Enum, auto -from typing import Any, List, Optional, Tuple, Union +from enum import Enum +from typing import Any, Optional from pydantic import BaseModel, ConfigDict, model_validator from typing_extensions import Self @@ -13,56 +13,114 @@ class InitStdFactor(str, Enum): DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 -class BaseTransformerArgs(BaseModel): +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class LMTransformerArgs(BaseModel): + """Arguments for the Language Model Transformer (used as entropy model for patching)""" model_config = ConfigDict() + + # Basic architecture dim: int = 512 n_layers: int = 8 head_dim: int | None = None n_heads: int | None = None n_kv_heads: int | None = None - + + # Transformer configuration + max_seqlen: int = 1024 + norm_eps: float = 1e-5 + dropout: float = 0 + vocab_size: int = -1 + sliding_window: int | None = None + + # Feedforward ffn_dim_multiplier: float | None = None - multiple_of: int = 256 - - norm_eps: float = 1e-5 - + + # Positional encoding rope_theta: float = 10000.0 rope_use_fp32_in_outer_product: bool = False - + + # Attention + attn_impl: str = "sdpa" + attn_bias_type: str = "causal" + + # Initialization init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED - - max_seqlen: int = 1024 - - attn_impl: str | None = "sdpa" - attn_bias_type: str | None = None + + # Embedding dimensions + dim_token_emb: int | None = None + + # Model behavior + weight_tying: bool = False + seed: int = 42 + # Special token config - eos_id: int | None = EOS_ID + eos_id: int = EOS_ID + -class ByteLatentTransformerArgs(BaseTransformerArgs): +class ByteLatentTransformerArgs(BaseModel): + """Arguments for the Byte Latent Transformer (main BLT model)""" + model_config = ConfigDict() + # Basic model configuration seed: int = 42 vocab_size: int = -1 + + # Main architecture dimensions (these will be used for creating transformer args) dim: int = 512 n_layers: int = 8 - n_heads: int = 8 - # TODO: What is the purpose of this parameter? - weight_tying: bool = False - patch_in_forward: bool = False - - realtime_patching: bool = True - - # Architecture and dimensions - dim_token: int | None = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None + + # Component-specific dimensions dim_global: int = 512 dim_local_decoder: int = 512 dim_local_encoder: int = 512 n_layers_global: int = 8 n_layers_local_decoder: int = 8 n_layers_local_encoder: int = 8 - - # Tokenization and patching + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads_global: int | None = None + + # Transformer configuration (needed by transformer components) + max_seqlen: int = 1024 + norm_eps: float = 1e-5 + dropout: float = 0 + + # Feedforward (needed by transformer components) + ffn_dim_multiplier: float = 1.0 + multiple_of: int = 256 + + # Positional encoding (needed by transformer components) + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + # Attention (needed by transformer components) + attn_impl: str = "sdpa" + attn_bias_type: str = "causal" + + # Initialization (needed by transformer components) + init_base_std: float | None = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + # Embedding dimensions (needed by transformer components) + dim_token_emb: int | None = None + + # Patching configuration + patch_in_forward: bool = False + realtime_patching: bool = True patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -71,17 +129,8 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): patching_batch_size: int = 1 patching_device: str = "cuda" max_patch_length: int | None = None - - # Encoder/Decoder configuration - tie_local_encoder_decoder_logits: bool = False - use_local_encoder_transformer: bool = False - encoder_lm_loss: bool = False - max_encoder_seq_length: int | None = None - pad_to_max_length: bool = False - encoder_enable_byte_ngrams: bool = False - encoder_enable_byte_group_hash: bool = False - ngram_vocab_sizes: int | None = None - + entropy_model_checkpoint_dir: str | None = None + # Cross attention configurations cross_attn_encoder: bool = False cross_attn_decoder: bool = False @@ -93,81 +142,37 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): cross_attn_all_layers_encoder: bool = False cross_attn_use_flex_attention: bool = True cross_attn_init_by_pooling: bool = False - - # Encoder hash configurations + + # Encoder configurations + use_local_encoder_transformer: bool = False + max_encoder_seq_length: int | None = None encoder_hash_byte_group_size: Any | None = None encoder_hash_byte_group_vocab: int = 30000 encoder_hash_byte_group_nb_functions: int = 3 - - # Model behavior and optimization - log_patch_lengths: bool = False - non_linearity: str = "swiglu" - use_rope: bool = True - recompute_fc1_out: bool = False - recompute_fc3_out: bool = False - recompute_attn: bool = True - custom_bwd: bool = False - layer_ckpt: str = "all" - - # Initialization and attention - init_use_gaussian: bool = True - init_use_depth: str = "current" - attn_bias_type: str = "causal" - alpha_depth: str = "disabled" - max_length: int = 2048 - - # Norm configuration - norm_eps: float = 1e-5 - norm_affine: bool = True - pre_norm: bool = True - norm_type: str = "rmsnorm" - - # Additional configurations - multiple_of: int = 256 - ffn_dim_multiplier: float = 1.0 - dropout: float = 0 - output_size: int = -1 - - # Additional parameters from ModelArgs - architecture: str = "vanilla" - share_encoder_decoder_emb: bool = True - global_local_decoder_residual_layer: str | None = None - - tokenize_with_bpe_delimiter: bool = False - patching_thresholds_str: str | None = None - tie_local_encoder_decoder: bool = False - encoder_preds_low_entropy_toks: float | None = None - encoder_preds_random_toks: float | None = None - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - encoder_ngram_table_dir: str | None = None + encoder_enable_byte_ngrams: bool = False encoder_ngram_to_size_str: str | None = None - - # Model architecture params - entropy_model_checkpoint_dir: str | None = None - entropy_model_is_ngram_model: bool = False downsampling_by_pooling: str | None = None - n_heads_global: int = 8 - n_heads_local_decoder: int = 8 - n_heads_local_encoder: int = 8 - n_kv_heads: int | None = None - n_kv_heads_global: int | None = None - conv_kernel_size: int | None = None + + # Architecture and dimensions + dim_token: int | None = None + share_encoder_decoder_emb: bool = True + weight_tying: bool = False + + # Attention configuration local_attention_window_len: int | None = None - + use_rope: bool = True + # Performance optimization sequence_parallel: bool = False loss_parallel: bool = False fuse_sequence_parallel: bool = False use_fsdp: bool = True - attn_to_keep: str = "all" - + # Parameter mixing pm_size: int = 0 - - # Logging - full_logging_n_layers: int = 4 + + # Special token config + eos_id: int = EOS_ID @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: @@ -182,71 +187,3 @@ def check_hash_byte_sizes(self) -> Self: ] return self - -class GlobalTransformerArgs(ByteLatentTransformerArgs): - # Global encoder specific dimensions - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with global encoder specific values - self.dim = self.dim_global - self.n_layers = self.n_layers_global - self.n_heads = self.n_heads_global - self.n_kv_heads = self.n_kv_heads_global - self.local_attention_window_len = None - self.cross_attn_encoder = False - self.cross_attn_decoder = False - - -class LocalDecoderArgs(ByteLatentTransformerArgs): - # Local decoder specific dimensions - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local decoder specific values - self.dim = self.dim_local_decoder - self.n_layers = self.n_layers_local_decoder - self.n_heads = self.n_heads_local_decoder - self.cross_attn_encoder = False - self.cross_attn_init_by_pooling = False - self.attn_bias_type = "local_block_causal" - - -class LocalModelArgs(BaseTransformerArgs): - model_config = ConfigDict() - # Override defaults - attn_impl: str | None = "sdpa" # originally xformers - attn_bias_type: str | None = "local_block_causal" - - # Local encoder specific dimensions - dropout: float - vocab_size: int - patch_size: float - sliding_window: int | None - use_rope: bool - cross_attn_encoder: bool | None - cross_attn_decoder: bool | None - cross_attn_k: int | None - cross_attn_init_by_pooling: bool - patching_mode: str - use_local_encoder_transformer: bool - downsampling_by_pooling: str | None - encoder_hash_byte_group_size: Any | None = None - cross_attn_all_layers_encoder: bool = False - cross_attn_all_layers_decoder: bool = False - cross_attn_nheads: int | None - - dim_token_emb: int - dim_patch_emb: int | None - - -class LMTransformerArgs(BaseTransformerArgs): - seed: int = 42 - - vocab_size: int = -1 - weight_tying: bool = False - - sliding_window: int | None = None - diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py new file mode 100644 index 000000000000..d6fc85789329 --- /dev/null +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -0,0 +1,563 @@ +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT (Byte Latent Transformer) model configuration""" + +from enum import Enum +from typing import Any, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class BLTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a + BLT model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. + max_seqlen (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can handle. + + # Main architecture dimensions + dim (`int`, *optional*, defaults to 512): + Main dimension of the model. + n_layers (`int`, *optional*, defaults to 8): + Number of layers in the main transformer. + n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the main transformer. + head_dim (`int`, *optional*): + Dimension of each attention head. If not specified, computed as dim // n_heads. + n_kv_heads (`int`, *optional*): + Number of key-value heads for grouped query attention. If not specified, defaults to n_heads. + + # Component-specific dimensions + dim_global (`int`, *optional*, defaults to 512): + Dimension of the global transformer component. + dim_local_decoder (`int`, *optional*, defaults to 512): + Dimension of the local decoder component. + dim_local_encoder (`int`, *optional*, defaults to 512): + Dimension of the local encoder component. + n_layers_global (`int`, *optional*, defaults to 8): + Number of layers in the global transformer. + n_layers_local_decoder (`int`, *optional*, defaults to 8): + Number of layers in the local decoder. + n_layers_local_encoder (`int`, *optional*, defaults to 8): + Number of layers in the local encoder. + n_heads_global (`int`, *optional*, defaults to 8): + Number of attention heads in the global transformer. + n_heads_local_decoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local decoder. + n_heads_local_encoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local encoder. + n_kv_heads_global (`int`, *optional*): + Number of key-value heads in the global transformer. + + # Transformer configuration + norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers. + ffn_dim_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier for the feedforward network dimension. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward network dimension multiple of this value. + + # Positional encoding + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product computation. + + # Attention configuration + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation to use ("sdpa" or "flex_attention"). + attn_bias_type (`str`, *optional*, defaults to "causal"): + Type of attention bias to apply. + local_attention_window_len (`int`, *optional*): + Window length for local attention. + use_rope (`bool`, *optional*, defaults to True): + Whether to use rotary position embeddings. + + # Initialization + init_base_std (`float`, *optional*): + Base standard deviation for weight initialization. + init_std_factor (`str`, *optional*, defaults to "disabled"): + Factor for adjusting initialization standard deviation. + + # Embedding dimensions + dim_token_emb (`int`, *optional*): + Token embedding dimension. + dim_token (`int`, *optional*): + Token dimension. + + # Patching configuration + patch_in_forward (`bool`, *optional*, defaults to False): + Whether to perform patching during forward pass. + realtime_patching (`bool`, *optional*, defaults to True): + Whether to use realtime patching. + patch_size (`float`, *optional*): + Size of patches for static patching. + patching_mode (`str`, *optional*): + Mode for patching ("entropy", "static", etc.). + patching_threshold (`float`, *optional*): + Threshold for entropy-based patching. + patching_threshold_add (`float`, *optional*): + Additional threshold parameter for patching. + monotonicity (`bool`, *optional*, defaults to False): + Whether to enforce monotonicity in patching. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size for patching operations. + patching_device (`str`, *optional*, defaults to "cuda"): + Device to use for patching operations. + max_patch_length (`int`, *optional*): + Maximum length of patches. + entropy_model_checkpoint_dir (`str`, *optional*): + Directory containing entropy model checkpoint. + + # Cross attention configurations + cross_attn_encoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in encoder. + cross_attn_decoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in decoder. + cross_attn_window_encoder (`int`, *optional*): + Cross attention window for encoder. + cross_attn_window_decoder (`int`, *optional*): + Cross attention window for decoder. + cross_attn_k (`int`, *optional*): + Number of cross attention components. + cross_attn_nheads (`int`, *optional*): + Number of heads for cross attention. + cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all decoder layers. + cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all encoder layers. + cross_attn_use_flex_attention (`bool`, *optional*, defaults to True): + Whether to use flexible attention for cross attention. + cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): + Whether to initialize cross attention by pooling. + + # Encoder configurations + use_local_encoder_transformer (`bool`, *optional*, defaults to False): + Whether to use transformer in local encoder. + max_encoder_seq_length (`int`, *optional*): + Maximum sequence length for encoder. + encoder_hash_byte_group_size (`Any`, *optional*): + Hash byte group size for encoder. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): + Vocabulary size for hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): + Number of hash functions for byte groups. + encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False): + Whether to enable byte n-grams in encoder. + encoder_ngram_to_size_str (`str`, *optional*): + String defining n-gram sizes. + downsampling_by_pooling (`str`, *optional*): + Type of pooling for downsampling. + + # Model behavior + share_encoder_decoder_emb (`bool`, *optional*, defaults to True): + Whether to share encoder and decoder embeddings. + weight_tying (`bool`, *optional*, defaults to False): + Whether to tie input and output embeddings. + + # Performance optimization + sequence_parallel (`bool`, *optional*, defaults to False): + Whether to use sequence parallelism. + loss_parallel (`bool`, *optional*, defaults to False): + Whether to use loss parallelism. + fuse_sequence_parallel (`bool`, *optional*, defaults to False): + Whether to fuse sequence parallel operations. + use_fsdp (`bool`, *optional*, defaults to True): + Whether to use fully sharded data parallel. + + # Parameter mixing + pm_size (`int`, *optional*, defaults to 0): + Parameter mixing size. + + # Special tokens + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to -1): + The id of the padding token. + + ```python + >>> from transformers import ByteLatentTransformer, BLTConfig + + >>> # Initializing a BLT configuration + >>> configuration = BLTConfig() + + >>> # Initializing a model from the configuration + >>> model = ByteLatentTransformer(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256, + max_seqlen=1024, + + # Main architecture dimensions + dim=512, + n_layers=8, + n_heads=8, + head_dim=None, + n_kv_heads=None, + + # Component-specific dimensions + dim_global=512, + dim_local_decoder=512, + dim_local_encoder=512, + n_layers_global=8, + n_layers_local_decoder=8, + n_layers_local_encoder=8, + n_heads_global=8, + n_heads_local_decoder=8, + n_heads_local_encoder=8, + n_kv_heads_global=None, + + # Transformer configuration + norm_eps=1e-5, + dropout=0.0, + ffn_dim_multiplier=1.0, + multiple_of=256, + + # Positional encoding + rope_theta=10000.0, + rope_use_fp32_in_outer_product=False, + + # Attention configuration + attn_impl="sdpa", + attn_bias_type="causal", + local_attention_window_len=None, + use_rope=True, + + # Initialization + init_base_std=None, + init_std_factor="disabled", + + # Embedding dimensions + dim_token_emb=None, + dim_token=None, + + # Patching configuration + patch_in_forward=False, + realtime_patching=True, + patch_size=None, + patching_mode=None, + patching_threshold=None, + patching_threshold_add=None, + monotonicity=False, + patching_batch_size=1, + patching_device="cuda", + max_patch_length=None, + entropy_model_checkpoint_dir=None, + + # Cross attention configurations + cross_attn_encoder=False, + cross_attn_decoder=False, + cross_attn_window_encoder=None, + cross_attn_window_decoder=None, + cross_attn_k=None, + cross_attn_nheads=None, + cross_attn_all_layers_decoder=False, + cross_attn_all_layers_encoder=False, + cross_attn_use_flex_attention=True, + cross_attn_init_by_pooling=False, + + # Encoder configurations + use_local_encoder_transformer=False, + max_encoder_seq_length=None, + encoder_hash_byte_group_size=None, + encoder_hash_byte_group_vocab=30000, + encoder_hash_byte_group_nb_functions=3, + encoder_enable_byte_ngrams=False, + encoder_ngram_to_size_str=None, + downsampling_by_pooling=None, + + # Model behavior + share_encoder_decoder_emb=True, + weight_tying=False, + + # Performance optimization + sequence_parallel=False, + loss_parallel=False, + fuse_sequence_parallel=False, + use_fsdp=True, + + # Parameter mixing + pm_size=0, + + # Special tokens + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, + + # Inherited + **kwargs, + ): + # Basic model configuration + self.vocab_size = vocab_size + self.max_seqlen = max_seqlen + + # Main architecture dimensions + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.head_dim = head_dim + self.n_kv_heads = n_kv_heads + + # Component-specific dimensions + self.dim_global = dim_global + self.dim_local_decoder = dim_local_decoder + self.dim_local_encoder = dim_local_encoder + self.n_layers_global = n_layers_global + self.n_layers_local_decoder = n_layers_local_decoder + self.n_layers_local_encoder = n_layers_local_encoder + self.n_heads_global = n_heads_global + self.n_heads_local_decoder = n_heads_local_decoder + self.n_heads_local_encoder = n_heads_local_encoder + self.n_kv_heads_global = n_kv_heads_global + + # Transformer configuration + self.norm_eps = norm_eps + self.dropout = dropout + self.ffn_dim_multiplier = ffn_dim_multiplier + self.multiple_of = multiple_of + + # Positional encoding + self.rope_theta = rope_theta + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + # Attention configuration + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.local_attention_window_len = local_attention_window_len + self.use_rope = use_rope + + # Initialization + self.init_base_std = init_base_std + self.init_std_factor = InitStdFactor(init_std_factor) + + # Embedding dimensions + self.dim_token_emb = dim_token_emb + self.dim_token = dim_token + + # Patching configuration + self.patch_in_forward = patch_in_forward + self.realtime_patching = realtime_patching + self.patch_size = patch_size + self.patching_mode = patching_mode + self.patching_threshold = patching_threshold + self.patching_threshold_add = patching_threshold_add + self.monotonicity = monotonicity + self.patching_batch_size = patching_batch_size + self.patching_device = patching_device + self.max_patch_length = max_patch_length + self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir + + # Cross attention configurations + self.cross_attn_encoder = cross_attn_encoder + self.cross_attn_decoder = cross_attn_decoder + self.cross_attn_window_encoder = cross_attn_window_encoder + self.cross_attn_window_decoder = cross_attn_window_decoder + self.cross_attn_k = cross_attn_k + self.cross_attn_nheads = cross_attn_nheads + self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder + self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder + self.cross_attn_use_flex_attention = cross_attn_use_flex_attention + self.cross_attn_init_by_pooling = cross_attn_init_by_pooling + + # Encoder configurations + self.use_local_encoder_transformer = use_local_encoder_transformer + self.max_encoder_seq_length = max_encoder_seq_length + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions + self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams + self.encoder_ngram_to_size_str = encoder_ngram_to_size_str + self.downsampling_by_pooling = downsampling_by_pooling + + # Model behavior + self.share_encoder_decoder_emb = share_encoder_decoder_emb + self.weight_tying = weight_tying + + # Performance optimization + self.sequence_parallel = sequence_parallel + self.loss_parallel = loss_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.use_fsdp = use_fsdp + + # Parameter mixing + self.pm_size = pm_size + + # Handle hash byte group size validation + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) + + +# Separate config for the LM Transformer (entropy model) +class LMTransformerConfig(PretrainedConfig): + r""" + Configuration class for the LM Transformer used as entropy model in BLT patching. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the LM model. + dim (`int`, *optional*, defaults to 512): + Dimension of the hidden representations. + n_layers (`int`, *optional*, defaults to 8): + Number of hidden layers. + n_heads (`int`, *optional*, defaults to 8): + Number of attention heads. + head_dim (`int`, *optional*): + Dimension of each attention head. + n_kv_heads (`int`, *optional*): + Number of key-value heads for grouped query attention. + max_seqlen (`int`, *optional*, defaults to 1024): + Maximum sequence length. + norm_eps (`float`, *optional*, defaults to 1e-5): + Epsilon for layer normalization. + dropout (`float`, *optional*, defaults to 0.0): + Dropout probability. + sliding_window (`int`, *optional*): + Sliding window size for attention. + ffn_dim_multiplier (`float`, *optional*): + Multiplier for feedforward dimension. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this. + rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter. + rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product. + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation. + attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type. + init_base_std (`float`, *optional*): + Base initialization standard deviation. + init_std_factor (`str`, *optional*, defaults to "disabled"): + Initialization std factor. + dim_token_emb (`int`, *optional*): + Token embedding dimension. + weight_tying (`bool`, *optional*, defaults to False): + Whether to tie embeddings. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of sequence token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of sequence token id. + """ + + model_type = "lm_transformer" + + def __init__( + self, + vocab_size=256, + dim=512, + n_layers=8, + n_heads=8, + head_dim=None, + n_kv_heads=None, + max_seqlen=1024, + norm_eps=1e-5, + dropout=0.0, + sliding_window=None, + ffn_dim_multiplier=None, + multiple_of=256, + rope_theta=10000.0, + rope_use_fp32_in_outer_product=False, + attn_impl="sdpa", + attn_bias_type="causal", + init_base_std=None, + init_std_factor="disabled", + dim_token_emb=None, + weight_tying=False, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.head_dim = head_dim + self.n_kv_heads = n_kv_heads + self.max_seqlen = max_seqlen + self.norm_eps = norm_eps + self.dropout = dropout + self.sliding_window = sliding_window + self.ffn_dim_multiplier = ffn_dim_multiplier + self.multiple_of = multiple_of + self.rope_theta = rope_theta + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.init_base_std = init_base_std + self.init_std_factor = InitStdFactor(init_std_factor) + self.dim_token_emb = dim_token_emb + self.weight_tying = weight_tying + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +__all__ = ["BLTConfig", "LMTransformerConfig", "InitStdFactor", "PatchingModeEnum"] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 661c3f304b46..1b4ec71e2c9f 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -21,8 +21,7 @@ import abc import os -import time -from collections import defaultdict +from contextlib import nullcontext from pydantic import BaseModel @@ -40,14 +39,11 @@ logger = logging.getLogger() -from .blt_args import ( - BaseTransformerArgs, - ByteLatentTransformerArgs, - GlobalTransformerArgs, - LocalDecoderArgs, - LocalModelArgs, - LMTransformerArgs, - +from .configuration_blt import ( + BLTConfig, + LMTransformerConfig, + PatchingModeEnum, + InitStdFactor, ) flex_attention_comp = flex_attention @@ -455,7 +451,7 @@ def reset_parameters(self, init_std=None, factor=1.0): class TransformerBlock(nn.Module): - def __init__(self, args: BaseTransformerArgs): + def __init__(self, args): super().__init__() assert (args.head_dim is not None) or ( @@ -520,25 +516,26 @@ def get_output_seq_len(self) -> int: class BaseTransformer(nn.Module, SequenceModelWithOutput): - def __init__(self, args: BaseTransformerArgs): + def __init__(self, config): super().__init__() - self.dim = args.dim - self.init_base_std = args.init_base_std - self.attn_impl = args.attn_impl - self.attn_bias_type = args.attn_bias_type - self.init_std_factor = InitStdFactor(args.init_std_factor) - self.max_seqlen = args.max_seqlen + self.dim = config.dim + self.init_base_std = config.init_base_std + self.attn_impl = config.attn_impl + self.attn_bias_type = config.attn_bias_type + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen self.rope_embeddings = RotaryEmbedding( - theta=args.rope_theta, - head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=args.max_seqlen, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + theta=config.rope_theta, + head_dim=config.head_dim or config.dim // config.n_heads, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, ) - self.eos_id = args.eos_id + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) self.layers = nn.ModuleList() - for _ in range(args.n_layers): - self.layers.append(TransformerBlock(args)) + for _ in range(config.n_layers): + self.layers.append(TransformerBlock(config)) def get_output_seq_len(self): return self.max_seqlen @@ -580,30 +577,30 @@ class LMTransformer( license_name="fair-noncommercial-research-license", license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", coders={ - LMTransformerArgs: ( - lambda x: {"args": x.model_dump()}, - lambda data: LMTransformerArgs(**data), + LMTransformerConfig: ( + lambda x: {"config": x.to_dict()}, + lambda data: LMTransformerConfig(**data), ) }, ): - def __init__(self, args: LMTransformerArgs): - super().__init__(args) - self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window + def __init__(self, config: LMTransformerConfig): + super().__init__(config) + self.weight_tying = config.weight_tying + self.sliding_window = config.sliding_window - assert args.vocab_size > 0 + assert config.vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + self.tok_embeddings = torch.nn.Embedding(config.vocab_size, config.dim) - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) self.output = nn.Linear( - args.dim, - args.vocab_size, + config.dim, + config.vocab_size, bias=False, ) - if args.weight_tying: + if config.weight_tying: self.output.weight = self.embeddings.tok_embeddings.weight def push_to_hub(self, *args, **kwargs): @@ -670,32 +667,6 @@ def init_weights(self): ) -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - -class PatcherArgs(BaseModel): - patching_mode: PatchingModeEnum = PatchingModeEnum.entropy - patching_device: str = "cpu" - entropy_model_checkpoint_dir: str | None = None - realtime_patching: bool = False - threshold: float = 1.335442066192627 - threshold_add: float | None = None - max_patch_length: int | None = None - patch_size: float = 4.5 - patching_batch_size: int = 1 - device: str = "cpu" - monotonicity: bool = False - log_time: bool = False - - def build(self) -> "Patcher": - return Patcher(self) - def rightpad(seq, pad_id, max_len): return seq + [pad_id] * (max_len - len(seq)) @@ -880,186 +851,8 @@ def split_large_numbers(lst, m): assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" return new_lst -class Patcher: - def __init__(self, patcher_args: PatcherArgs): - self.patcher_args = patcher_args - self.patching_mode = patcher_args.patching_mode - self.realtime_patching = patcher_args.realtime_patching - self.realtime_patching = True - if self.realtime_patching: - assert ( - patcher_args.entropy_model_checkpoint_dir is not None - ), "Cannot require realtime patching without an entropy model checkpoint" - maybe_consolidated = os.path.join( - patcher_args.entropy_model_checkpoint_dir, - "consolidated/consolidated.pth", - ) - if os.path.exists(maybe_consolidated): - state_path = maybe_consolidated - else: - state_path = os.path.join( - patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" - ) - entropy_model, _ = load_entropy_model( - patcher_args.entropy_model_checkpoint_dir, - state_path, - ) - entropy_model = entropy_model.to(patcher_args.patching_device) - self.entropy_model = entropy_model - else: - self.entropy_model = None - self.threshold = patcher_args.threshold - self.threshold_add = patcher_args.threshold_add - self.max_patch_length = patcher_args.max_patch_length - self.patch_size = patcher_args.patch_size - self.patching_batch_size = patcher_args.patching_batch_size - self.device = patcher_args.device - self.monotonicity = patcher_args.monotonicity - self.log_time = patcher_args.log_time - if self.log_time: - self.log = defaultdict(float) - - def patch( - self, - tokens: torch.Tensor, - include_next_token: bool = False, - preds: torch.Tensor | None = None, - entropies: torch.Tensor | None = None, - threshold: float = None, - ) -> torch.Tensor: - """ - tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched - Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) - -> output tensor: [batch_size, max_num_patches] - each tensor is processed independently and gets right padded with zeros. - - Patching with the following modes: - 1. patching_mode = None: static patch size - 2. patching_mode = "entropy": - calculate entropy of each token, allocate patches so that the total - number of patches is the same as static patching but choose to begin - patches on tokens where the model is most uncertain (highest entropy). - - When threshold is provided, it uses the threshold to decide when to - start a new patch. - 3. patching_mode = "space": - use space like tokens to define the patches. - 4. patching_mode = "bpe": - use bpe delim tokens to define the patches. - - To correctly patch the last token, it may be necessary to include the next token in the patch - lengths calculations. This is controlled by the include_next_token argument. - """ - bs, seq_len = tokens.shape - seq_len_next_tok = seq_len + 1 if include_next_token else seq_len - scores = None - # STATIC - if self.patching_mode == PatchingModeEnum.byte: - patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device - ) - elif self.patching_mode == PatchingModeEnum.entropy: - if self.log_time: - s = time.time() - if entropies is not None: - scores = entropies.to(dtype=torch.float32) - elif preds is not None: - scores = entropy(preds) - else: - start_entropies = time.time() - scores, _ = calculate_entropies( - tokens, - self.entropy_model, - self.patching_batch_size, - self.device, - ) - if self.log_time: - self.log["calculate_entropies"] += time.time() - s - s = time.time() - patch_start_ids = find_entropy_patch_start_ids( - scores, - self.patch_size, - include_next_token=include_next_token, - threshold=threshold if threshold is not None else self.threshold, - threshold_add=self.threshold_add, - monotonicity=self.monotonicity, - ) - if self.log_time: - self.log["find_entropy_patch_start_ids"] += time.time() - s - s = time.time() - patch_lengths = patch_lengths_from_start_ids( - patch_start_ids, seq_len_next_tok - ) - if self.log_time: - self.log["patch_lengths_from_start_ids"] += time.time() - s - s = time.time() - else: - raise NotImplementedError(f"self.patching_mode {self.patching_mode}") - - # Apply any processing to patch lengths - if self.max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [ - split_large_numbers(pl, self.max_patch_length) - for pl in patch_lengths.tolist() - ] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=tokens.dtype, device=tokens.device - ) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] - assert ( - torch.sum(patch_lengths) - == tokens.numel() + include_next_token * tokens.shape[0] - ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" - if self.log_time: - self.log["postprocessing_patch_lengths"] += time.time() - s - self.log["tokens"] += patch_lengths.sum().item() - return patch_lengths, scores - -def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): - with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: - reloaded = json.loads(fr.read()) - - torch.set_default_dtype(torch.bfloat16) - model_params = reloaded["entropy_model"] - logger.warning( - "Update checkpoint to load attn and sliding window args from checkpoint" - ) - entropy_model_args = LMTransformerArgs( - dim=model_params["dim"], - n_layers=model_params["n_layers"], - n_heads=model_params["n_heads"], - max_seqlen=model_params["max_seqlen"], - ffn_dim_multiplier=model_params["ffn_dim_multiplier"], - vocab_size=model_params["vocab_size"], - attn_bias_type="local_block_causal", - attn_impl="sdpa", #originally xformers - sliding_window=512, - ) - entropy_model = LMTransformer(entropy_model_args) - - entropy_model.load_state_dict( - torch.load(state_dict_path, map_location=device)["model"], strict=False - ) - entropy_model.to(device) - entropy_model = entropy_model.eval() - # no grads for the model: - for param in entropy_model.parameters(): - param.requires_grad = False - return entropy_model, entropy_model_args - def get_encoder_dim_token_emb(args): if args.dim_token is not None: @@ -1409,69 +1202,105 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class LocalModelBase(nn.Module): - def __init__(self, args: LocalModelArgs): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() - self.dim = args.dim - self.dropout = args.dropout - self.vocab_size = args.vocab_size - self.patch_size = args.patch_size - self.dim_patch_emb = args.dim_patch_emb - - self.attn_impl = args.attn_impl - self.sliding_window = args.sliding_window - self.use_rope = args.use_rope - self.init_std_factor = args.init_std_factor - self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) - self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) - self.cross_attn_k = getattr(args, "cross_attn_k", None) - self.eos_id = args.eos_id + # Use component-specific dimensions + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") + + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size + + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_id self.boe_id = BOE_ID + + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None + + # Create component-specific config for TransformerBlocks by copying config and overriding dimensions + component_config = type(config)(**config.to_dict()) + component_config.dim = self.dim + component_config.n_layers = self.n_layers + component_config.n_heads = self.n_heads + if hasattr(config, 'attn_bias_type'): + component_config.attn_bias_type = self.attn_bias_type + if hasattr(config, 'max_seqlen'): + component_config.max_seqlen = self.max_seqlen self.layers = nn.ModuleList( - [TransformerBlock(args) for _ in range(args.n_layers)] + [TransformerBlock(component_config) for _ in range(self.n_layers)] ) if not self.use_rope: - self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length else: self.rope = RotaryEmbedding( - theta=args.rope_theta, - head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=args.max_seqlen, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, ) self.pos_embeddings = None + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = get_encoder_dim_token_emb(config) + self.dim_patch_emb = get_encoder_dim_patch_emb(config) + elif component_type == "decoder": + self.dim_token_emb = get_decoder_dim_token_emb(config) + self.dim_patch_emb = config.dim_global + self.token_embedding_projection = ( - nn.Linear(args.dim_token_emb, args.dim, bias=False) - if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim else None ) - self.patch_embedding_projection = self._create_patch_projection(args) + self.patch_embedding_projection = self._create_patch_projection(config) - def _should_create_patch_projection(self, args: LocalModelArgs): + def _should_create_patch_projection(self, config: BLTConfig): dimension_mismatch = ( - getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim + self.dim_patch_emb is not None and self.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - args.cross_attn_encoder and args.cross_attn_init_by_pooling - ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) + config.cross_attn_encoder and config.cross_attn_init_by_pooling + ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions - def _create_patch_projection(self, args): - if not self._should_create_patch_projection(args): + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): return None - output_dim = args.dim_token_emb * (self.cross_attn_k or 1) + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) return nn.Linear( - in_features=args.dim_patch_emb, + in_features=self.dim_patch_emb, out_features=output_dim, bias=False, ) @@ -1556,22 +1385,22 @@ def init_weights(self, init_std=None): class LocalEncoder(LocalModelBase): - def __init__(self, args: LocalModelArgs): - super().__init__(args) + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") - self.apply_transformer = args.use_local_encoder_transformer - self.downsampling_by_pooling = args.downsampling_by_pooling - self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None - self.cross_attn_encoder = args.cross_attn_encoder - self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling - self.cross_attn_nheads = args.cross_attn_nheads + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads - self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) if self.cross_attn_encoder: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( CrossAttention( @@ -1579,7 +1408,7 @@ def __init__(self, args: LocalModelArgs): head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, - norm_eps=args.norm_eps, + norm_eps=config.norm_eps, ) ) @@ -1691,20 +1520,20 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): class LocalDecoder(LocalModelBase): - def __init__(self, args: LocalModelArgs): - super().__init__(args) + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="decoder") # Model configuration flags - self.cross_attn_decoder = args.cross_attn_decoder - self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder - self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling - self.cross_attn_nheads = args.cross_attn_nheads + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = RMSNorm(self.dim, eps=config.norm_eps) if self.cross_attn_decoder: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( CrossAttention( @@ -1712,13 +1541,13 @@ def __init__(self, args: LocalModelArgs): head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, - norm_eps=args.norm_eps, + norm_eps=config.norm_eps, ) ) self.output = nn.Linear( self.dim, - args.vocab_size, + config.vocab_size, bias=False, ) @@ -1912,17 +1741,17 @@ def init_weights(self, base_std: float, factor: float = 1.0): class GlobalTransformer(BaseTransformer): - def __init__(self, args: BaseTransformerArgs): - super().__init__(args) - self.dropout = args.dropout - self.eos_id = args.eos_id - self.dim_token_emb = args.dim_token_emb + def __init__(self, config): + super().__init__(config) + self.dropout = config.dropout + # eos_id is already set in BaseTransformer + self.dim_token_emb = config.dim_token_emb self.token_embedding_projection = None - if args.dim_token_emb is not None and args.dim_token_emb != self.dim: + if config.dim_token_emb is not None and config.dim_token_emb != self.dim: self.token_embedding_projection = nn.Linear( - args.dim_token_emb, - args.dim, + config.dim_token_emb, + config.dim, bias=False, ) @@ -1980,25 +1809,25 @@ class EmbeddingType(Enum): def init_embeddings( - args, + config, embedding_type: EmbeddingType, local_encoder_dim: int, encoder_hash_byte_group_size: list = None, ): if ( embedding_type == EmbeddingType.HASH_TOK - and args.encoder_hash_byte_group_size is None + and config.encoder_hash_byte_group_size is None ): return None - if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: + if embedding_type == EmbeddingType.NGRAM and config.encoder_ngram_to_size_str is None: return None embeddings = [] if embedding_type == EmbeddingType.HASH_TOK: emb_dim = local_encoder_dim - encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab - for _ in range(args.encoder_hash_byte_group_nb_functions): + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + for _ in range(config.encoder_hash_byte_group_nb_functions): for _ in encoder_hash_byte_group_size: embeddings.append( nn.Embedding( @@ -2008,7 +1837,7 @@ def init_embeddings( ) elif embedding_type == EmbeddingType.NGRAM: - encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) + encoder_ngram_to_size = parse_ngram_to_size(config.encoder_ngram_to_size_str) emb_dim = local_encoder_dim OFFSET = 4 # This should be passed as parameter if it's variable for ngram_vocab_size in encoder_ngram_to_size.values(): @@ -2072,9 +1901,9 @@ class ByteLatentTransformer( license_name="fair-noncommercial-research-license", license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", coders={ - ByteLatentTransformerArgs: ( - lambda x: {"args": x.model_dump()}, - lambda data: ByteLatentTransformerArgs(**data), + BLTConfig: ( + lambda x: {"config": x.to_dict()}, + lambda data: BLTConfig(**data), ) }, ): @@ -2085,145 +1914,66 @@ class ByteLatentTransformer( improved performance and inference efficiency. """ - def __init__(self, args: ByteLatentTransformerArgs): + def __init__(self, config: BLTConfig): super().__init__() + # Store config reference + self.config = config + # General configuration - self.weight_tying = args.weight_tying - self.patch_size = args.patch_size - self.patching_mode = args.patching_mode + self.weight_tying = config.weight_tying + self.patch_size = config.patch_size + self.patching_mode = config.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( BOE_ID, BOS_ID, PAD_ID, - EOS_ID, + config.eos_token_id, ) - self.downsampling_by_pooling = args.downsampling_by_pooling - self.patching_threshold = args.patching_threshold - self.dim = args.dim - self.init_base_std = args.init_base_std - self.init_std_factor = InitStdFactor(args.init_std_factor) - self.max_seqlen = args.max_seqlen + self.downsampling_by_pooling = config.downsampling_by_pooling + self.patching_threshold = config.patching_threshold + self.dim = config.dim + self.init_base_std = config.init_base_std + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen # Cross attention configuration - self.cross_attn_encoder = args.cross_attn_encoder - self.cross_attn_decoder = args.cross_attn_decoder - self.cross_attn_k = args.cross_attn_k - self.cross_attn_window_encoder = args.cross_attn_window_encoder - self.cross_attn_window_decoder = args.cross_attn_window_decoder - self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_k = config.cross_attn_k + self.cross_attn_window_encoder = config.cross_attn_window_encoder + self.cross_attn_window_decoder = config.cross_attn_window_decoder + self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention # Encoder hash configuration - self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size - self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_size = config.encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = ( - args.encoder_hash_byte_group_nb_functions + config.encoder_hash_byte_group_nb_functions ) - # ByteLatent modules - local_encoder_args = LocalModelArgs( - # Updated args - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_encoder=args.cross_attn_encoder, - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, - # Defaults - head_dim=args.head_dim, - max_seqlen=args.max_encoder_seq_length, - dropout=args.dropout, - vocab_size=args.vocab_size + args.pm_size, - norm_eps=args.norm_eps, - patch_size=args.patch_size, - sliding_window=args.local_attention_window_len, - use_rope=args.use_rope, - rope_theta=args.rope_theta, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, - init_base_std=args.init_base_std, - init_std_factor=args.init_std_factor, - n_kv_heads=args.n_kv_heads, - attn_impl=args.attn_impl, - attn_bias_type="local_block_causal", - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - patching_mode=args.patching_mode, - use_local_encoder_transformer=args.use_local_encoder_transformer, - downsampling_by_pooling=args.downsampling_by_pooling, - encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, - cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, - cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, - cross_attn_nheads=args.cross_attn_nheads, - eos_id=args.eos_id, - ) - self.local_encoder = LocalEncoder(local_encoder_args) - local_decoder_args = LocalModelArgs( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_encoder=False, - cross_attn_decoder=args.cross_attn_decoder, - cross_attn_init_by_pooling=False, # states are already defined - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - # Defaults - head_dim=args.head_dim, - max_seqlen=args.max_encoder_seq_length, - dropout=args.dropout, - vocab_size=args.vocab_size + args.pm_size, - norm_eps=args.norm_eps, - patch_size=args.patch_size, - sliding_window=args.local_attention_window_len, - use_rope=args.use_rope, - rope_theta=args.rope_theta, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, - init_base_std=args.init_base_std, - init_std_factor=args.init_std_factor, - n_kv_heads=args.n_kv_heads, - attn_impl=args.attn_impl, - attn_bias_type="local_block_causal", - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - patching_mode=args.patching_mode, - use_local_encoder_transformer=args.use_local_encoder_transformer, - downsampling_by_pooling=args.downsampling_by_pooling, - encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, - cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, - cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, - cross_attn_nheads=args.cross_attn_nheads, - eos_id=args.eos_id, - ) + # ByteLatent modules - pass the full config directly to local models + self.local_encoder = LocalEncoder(config) - global_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_global, - n_layers=args.n_layers_global, - n_heads=args.n_heads_global, - n_kv_heads=args.n_kv_heads_global, - local_attention_window_len=None, - dim_token_emb=get_global_dim_patch_emb(args), - dim_patch_emb=None, - cross_attn_encoder=False, - cross_attn_decoder=False, - ), - ) + # Create global-specific config by copying config and overriding dimensions + global_config = type(config)(**config.to_dict()) + global_config.dim = config.dim_global + global_config.n_layers = config.n_layers_global + global_config.n_heads = config.n_heads_global + global_config.n_kv_heads = config.n_kv_heads_global + global_config.dim_token_emb = get_global_dim_patch_emb(config) - self.global_transformer = GlobalTransformer(global_args) + self.global_transformer = GlobalTransformer(global_config) - self.local_decoder = LocalDecoder(local_decoder_args) + self.local_decoder = LocalDecoder(config) self.encoder_hash_tok_embedding = init_embeddings( - args, + config, EmbeddingType.HASH_TOK, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, ) self.encoder_ngram_embedding = init_embeddings( - args, + config, EmbeddingType.NGRAM, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=None, @@ -2231,34 +1981,185 @@ def __init__(self, args: ByteLatentTransformerArgs): # Encoder ngram embedding tables self.encoder_ngram_embedding = None - if args.encoder_enable_byte_ngrams: + if config.encoder_enable_byte_ngrams: self.encoder_ngram_embedding = nn.ModuleList() - assert args.ngram_vocab_sizes is not None self.encoder_ngram_to_size = parse_ngram_to_size( - args.encoder_ngram_to_size_str + config.encoder_ngram_to_size_str ) ngram_emb_dim = self.local_encoder.dim - for ngram_vocab_size in self.encoderngram_to_size.values(): + for ngram_vocab_size in self.encoder_ngram_to_size.values(): self.encoder_ngram_embedding.append( nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) ) # Output layer - assert args.vocab_size > 0, "vocab_size must be greater than 0" - - # Patcher module - if args.patch_in_forward: - self.patcher = Patcher( - PatcherArgs( - patch_size=args.patch_size, - patching_mode=args.patching_mode, - patching_threshold=args.patching_threshold, - patching_threshold_add=args.patching_threshold_add, - monotonicity=args.monotonicity, - max_patch_length=args.max_patch_length, - entropy_model_checkpoint_dir=args.entropy_model_checkpoint_dir + assert config.vocab_size > 0, "vocab_size must be greater than 0" + + # Patcher configuration + self.patch_in_forward = config.patch_in_forward + if config.patch_in_forward: + # Store patching parameters + self.patching_mode = config.patching_mode + self.patching_threshold = config.patching_threshold + self.patching_threshold_add = config.patching_threshold_add + self.monotonicity = config.monotonicity + self.max_patch_length = config.max_patch_length + self.patching_batch_size = config.patching_batch_size or 1 + self.patching_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #config.patching_device or "cuda" + + # Initialize entropy model (patcher) if realtime_patching is True + if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: + # Load entropy model directly + entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir + + if not os.path.exists(entropy_model_checkpoint_dir): + raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}") + + # Load entropy model parameters + params_path = os.path.join(entropy_model_checkpoint_dir, "params.json") + if not os.path.exists(params_path): + raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}") + + with open(params_path) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["entropy_model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) + entropy_model_config = LMTransformerConfig( + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + max_seqlen=model_params["max_seqlen"], + ffn_dim_multiplier=model_params["ffn_dim_multiplier"], + vocab_size=model_params["vocab_size"], + attn_bias_type="local_block_causal", + attn_impl="sdpa", # originally xformers + sliding_window=512, + ) + self.patcher = LMTransformer(entropy_model_config) + + # Load state dict + maybe_consolidated = os.path.join( + entropy_model_checkpoint_dir, + "consolidated/consolidated.pth", ) + if os.path.exists(maybe_consolidated): + state_path = maybe_consolidated + else: + state_path = os.path.join( + entropy_model_checkpoint_dir, "consolidated.pth" + ) + + if not os.path.exists(state_path): + raise FileNotFoundError(f"Model checkpoint not found at: {state_path}") + + self.patcher.load_state_dict( + torch.load(state_path, map_location=self.patching_device)["model"], strict=False + ) + self.patcher.to(self.patching_device) + self.patcher = self.patcher.eval() + # no grads for the model: + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + + + def patch( + self, + tokens: torch.Tensor, + include_next_token: bool = False, + preds: torch.Tensor | None = None, + entropies: torch.Tensor | None = None, + threshold: float = None, + ) -> torch.Tensor: + """ + tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched + Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) + -> output tensor: [batch_size, max_num_patches] + each tensor is processed independently and gets right padded with zeros. + + Patching with the following modes: + 1. patching_mode = None: static patch size + 2. patching_mode = "entropy": + calculate entropy of each token, allocate patches so that the total + number of patches is the same as static patching but choose to begin + patches on tokens where the model is most uncertain (highest entropy). + + When threshold is provided, it uses the threshold to decide when to + start a new patch. + 3. patching_mode = "space": + use space like tokens to define the patches. + 4. patching_mode = "bpe": + use bpe delim tokens to define the patches. + + To correctly patch the last token, it may be necessary to include the next token in the patch + lengths calculations. This is controlled by the include_next_token argument. + """ + bs, seq_len = tokens.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + scores = None + # STATIC + if self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device ) + elif self.patching_mode == PatchingModeEnum.entropy: + if entropies is not None: + scores = entropies.to(dtype=torch.float32) + elif preds is not None: + scores = entropy(preds) + else: + scores, _ = calculate_entropies( + tokens, + self.patcher, + self.patching_batch_size, + self.patching_device, + ) + patch_start_ids = find_entropy_patch_start_ids( + scores, + self.patch_size, + include_next_token=include_next_token, + threshold=threshold if threshold is not None else self.patching_threshold, + threshold_add=self.patching_threshold_add, + monotonicity=self.monotonicity, + ) + patch_lengths = patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + + # Apply any processing to patch lengths + if self.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + split_large_numbers(pl, self.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=tokens.dtype, device=tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + assert ( + torch.sum(patch_lengths) + == tokens.numel() + include_next_token * tokens.shape[0] + ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" + return patch_lengths, scores def push_to_hub(self, *args, **kwargs): raise ValueError( @@ -2294,12 +2195,12 @@ def forward( # Patching if patch_lengths is None: assert ( - getattr(self, "patcher", None) is not None - ), "Patcher not defined and no patch_lengths passed." - patch_lengths, tok_scores = self.patcher.patch( + getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward + ), "Patch in forward not enabled and no patch_lengths passed." + patch_lengths, tok_scores = self.patch( local_encoder_tokens, include_next_token=True, - threshold=self.patcher.threshold, + threshold=self.patching_threshold, ) else: if nb_boe > 0: From bdb6ceef95b2fac9d38611199df57e1d412e260d Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 6 Jun 2025 18:19:33 +0200 Subject: [PATCH 005/139] enable MPS --- src/demo_hf.py | 8 +++++--- src/transformers/models/blt_wip/modeling_blt_wip.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index c93599a0e09a..013b5217a69e 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -56,6 +56,7 @@ def generate( top_k: int = 0, top_p: float = 0.0, remove_prompts: bool = True, + device: torch.device = torch.device("cpu"), ) -> list[list[int]]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -113,8 +114,8 @@ def generate( -def main(prompt: str = "my name is", model_name: str = "blt-1b"): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def main(prompt: str = "hi", model_name: str = "blt-1b"): + device = "mps" #HF blt_repo = "facebook/blt-1b" @@ -163,7 +164,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): prompts, model=model, tokenizer=tokenizer, - max_gen_len=100 + max_gen_len=4, + device=device ) text_outputs = [tokenizer.decode(t) for t in outputs] diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 1b4ec71e2c9f..a14fa1ddb7ff 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -339,7 +339,7 @@ def forward( xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) assert mask is None or isinstance(mask, (str, torch.Tensor)) is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None output = F.scaled_dot_product_attention( xq, xk, @@ -1688,7 +1688,7 @@ def forward( #output = flex_attention_comp(xq, xk, xv, block_mask=mask) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask if isinstance(mask, torch.Tensor) else None - mask = mask.to(dtype=xq.dtype) + mask = mask.to(dtype=xq.dtype).to(xq.device) output = F.scaled_dot_product_attention( xq, xk, @@ -2211,7 +2211,7 @@ def forward( # Generate patch IDs from patch_lengths patch_ids = patch_ids_from_lengths( patch_lengths, local_encoder_tokens.shape[-1] - ) + ).to(tokens.device) assert torch.max(patch_ids) + 1 <= torch.max( (patch_lengths != 0).sum(dim=-1) ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" From 131f9608f5ade466f1720c53ebbb12ea6e52bdab Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 10 Jun 2025 12:54:06 +0000 Subject: [PATCH 006/139] refactoring unused code --- .../models/blt_wip/modeling_blt_wip.py | 360 ++++++++---------- 1 file changed, 152 insertions(+), 208 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index a14fa1ddb7ff..044ee0e7aed0 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -246,18 +246,6 @@ def forward( return self.freqs_cis[0:seqlen] -def _reshape_for_attn_bias( - attn_bias: None, - *tensors: torch.Tensor, -) -> list[torch.Tensor]: - to_transform = list(tensors) - if isinstance(attn_bias): - # could be `view` instead of reshape during training, but for inference - # have to reshape due to strides mismatch - to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] - return to_transform - - class Attention(nn.Module): def __init__( self, @@ -450,7 +438,7 @@ def reset_parameters(self, init_std=None, factor=1.0): ) -class TransformerBlock(nn.Module): +class TransformerLayer(nn.Module): def __init__(self, args): super().__init__() @@ -515,9 +503,27 @@ def get_output_seq_len(self) -> int: pass -class BaseTransformer(nn.Module, SequenceModelWithOutput): - def __init__(self, config): + +class LMTransformer( + nn.Module, + SequenceModelWithOutput, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + # paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + LMTransformerConfig: ( + lambda x: {"config": x.to_dict()}, + lambda data: LMTransformerConfig(**data), + ) + }, +): + def __init__(self, config: LMTransformerConfig): super().__init__() + self.dim = config.dim self.init_base_std = config.init_base_std self.attn_impl = config.attn_impl @@ -535,56 +541,9 @@ def __init__(self, config): self.layers = nn.ModuleList() for _ in range(config.n_layers): - self.layers.append(TransformerBlock(config)) - - def get_output_seq_len(self): - return self.max_seqlen - - def forward( - self, - h, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", - ): - - freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) - - for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) - return h - - def init_weights(self): - self.rope_embeddings.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - - layer.init_weights(self.init_base_std, factor) - - -class LMTransformer( - BaseTransformer, - PyTorchModelHubMixin, - repo_url="https://github.com/facebookresearch/blt", - # paper_url="https://arxiv.org/abs/2412.09871", - pipeline_tag="text-generation", - license="other", - license_name="fair-noncommercial-research-license", - license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", - coders={ - LMTransformerConfig: ( - lambda x: {"config": x.to_dict()}, - lambda data: LMTransformerConfig(**data), - ) - }, -): - def __init__(self, config: LMTransformerConfig): - super().__init__(config) + self.layers.append(TransformerLayer(config)) + + # LMTransformer specific attributes self.weight_tying = config.weight_tying self.sliding_window = config.sliding_window @@ -601,7 +560,10 @@ def __init__(self, config: LMTransformerConfig): ) if config.weight_tying: - self.output.weight = self.embeddings.tok_embeddings.weight + self.output.weight = self.tok_embeddings.weight + + def get_output_seq_len(self): + return self.max_seqlen def push_to_hub(self, *args, **kwargs): raise ValueError( @@ -621,7 +583,6 @@ def forward( bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) - # attn_impl = "sdpa" mask = ( mask if mask is not None @@ -634,7 +595,11 @@ def forward( eos_id=self.eos_id, ) ) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) logits = self.output(self.norm(h)) if target is not None: @@ -655,7 +620,17 @@ def init_weights(self): a=-3 * init_std, b=3 * init_std, ) - super().init_weights() + + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) if not self.weight_tying: nn.init.trunc_normal_( @@ -852,8 +827,6 @@ def split_large_numbers(lst, m): return new_lst - - def get_encoder_dim_token_emb(args): if args.dim_token is not None: dim_token_emb = args.dim_token @@ -904,18 +877,6 @@ def get_decoder_dim_token_emb(args): return dim_token_emb -def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: - if ngram_to_size_str is None: - return None - ngram_to_size = {} - for entry in ngram_to_size_str.split(","): - ngram, size = entry.split(":") - ngram = int(ngram) - size = int(size) - ngram_to_size[ngram] = size - return ngram_to_size - - def fill_tokens(tokens, patch_size, fill_id): batch_size, seq_len = tokens.shape if seq_len % patch_size == 0: @@ -1233,14 +1194,14 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) self.cross_attn_k = getattr(config, "cross_attn_k", None) - self.eos_id = config.eos_id + self.eos_id = config.eos_token_id self.boe_id = BOE_ID # Initialize cross attention layers as None (will be set by subclasses if needed) self.cross_attn_layers = None - # Create component-specific config for TransformerBlocks by copying config and overriding dimensions + # Create component-specific config for TransformerLayers by copying config and overriding dimensions component_config = type(config)(**config.to_dict()) component_config.dim = self.dim component_config.n_layers = self.n_layers @@ -1251,7 +1212,7 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): component_config.max_seqlen = self.max_seqlen self.layers = nn.ModuleList( - [TransformerBlock(component_config) for _ in range(self.n_layers)] + [TransformerLayer(component_config) for _ in range(self.n_layers)] ) if not self.use_rope: @@ -1455,39 +1416,27 @@ def forward( if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder ): - patch_embeds = self.apply_cross_attention( - h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, ) + patch_embeds = patch_embeds + patch_embeds_cross h_residual = patch_embeds if self.cross_attn_encoder else None return (h, h_residual), cache - def apply_cross_attention( - self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask - ): - # apply pooling and project - if self.cross_attn_init_by_pooling and patch_embeds is None: - # patch_embeds = downsample( - # h, - # num_patches, - # patch_ids=patch_ids, - # downsampling_by_pooling=self.downsampling_by_pooling, - # patch_size=self.patch_size, - # ) - patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) - if self.patch_embedding_projection is not None: - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape( - bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim - ) - layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 - patch_embeds_cross = self.cross_attn_layers[layer_idx]( - x=patch_embeds, - kv=h, - mask=cross_mask, - ) - return patch_embeds + patch_embeds_cross def patch_reduce(self, h, max_num_patches, reduction, patch_ids): """ @@ -1740,11 +1689,31 @@ def init_weights(self, base_std: float, factor: float = 1.0): self.cross_attn_norm_kv.reset_parameters() -class GlobalTransformer(BaseTransformer): +class GlobalTransformer(nn.Module, SequenceModelWithOutput): def __init__(self, config): - super().__init__(config) + super().__init__() + + self.dim = config.dim + self.init_base_std = config.init_base_std + self.attn_impl = config.attn_impl + self.attn_bias_type = config.attn_bias_type + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or config.dim // config.n_heads, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + + self.layers = nn.ModuleList() + for _ in range(config.n_layers): + self.layers.append(TransformerLayer(config)) + + # GlobalTransformer specific attributes self.dropout = config.dropout - # eos_id is already set in BaseTransformer self.dim_token_emb = config.dim_token_emb self.token_embedding_projection = None @@ -1754,6 +1723,9 @@ def __init__(self, config): config.dim, bias=False, ) + + def get_output_seq_len(self): + return self.max_seqlen def forward( self, @@ -1763,10 +1735,6 @@ def forward( mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - """ - Similar to BaseTransformer.forward, but with an additional embeds argument - and projection to the token space. - """ bs, seqlen = tokens.shape h = embeds @@ -1788,11 +1756,26 @@ def forward( h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + return h, cache def init_weights(self): - super().init_weights() + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) + + # GlobalTransformer specific initialization std = self.dim_token_emb ** (-0.5) if self.token_embedding_projection is not None: nn.init.trunc_normal_( @@ -1803,47 +1786,7 @@ def init_weights(self): b=3 * std, ) -class EmbeddingType(Enum): - HASH_TOK = auto() - NGRAM = auto() - - -def init_embeddings( - config, - embedding_type: EmbeddingType, - local_encoder_dim: int, - encoder_hash_byte_group_size: list = None, -): - if ( - embedding_type == EmbeddingType.HASH_TOK - and config.encoder_hash_byte_group_size is None - ): - return None - if embedding_type == EmbeddingType.NGRAM and config.encoder_ngram_to_size_str is None: - return None - - embeddings = [] - - if embedding_type == EmbeddingType.HASH_TOK: - emb_dim = local_encoder_dim - encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - for _ in range(config.encoder_hash_byte_group_nb_functions): - for _ in encoder_hash_byte_group_size: - embeddings.append( - nn.Embedding( - encoder_hash_byte_group_vocab, - emb_dim, - ) - ) - elif embedding_type == EmbeddingType.NGRAM: - encoder_ngram_to_size = parse_ngram_to_size(config.encoder_ngram_to_size_str) - emb_dim = local_encoder_dim - OFFSET = 4 # This should be passed as parameter if it's variable - for ngram_vocab_size in encoder_ngram_to_size.values(): - embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) - - return nn.ModuleList(embeddings) def compute_hash_embeddings( @@ -1919,7 +1862,7 @@ def __init__(self, config: BLTConfig): # Store config reference self.config = config - + # General configuration self.weight_tying = config.weight_tying self.patch_size = config.patch_size @@ -1966,31 +1909,28 @@ def __init__(self, config: BLTConfig): self.global_transformer = GlobalTransformer(global_config) self.local_decoder = LocalDecoder(config) - self.encoder_hash_tok_embedding = init_embeddings( + self.encoder_hash_tok_embedding = init_hash_embeddings( config, - EmbeddingType.HASH_TOK, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, ) - self.encoder_ngram_embedding = init_embeddings( - config, - EmbeddingType.NGRAM, - local_encoder_dim=self.local_encoder.dim, - encoder_hash_byte_group_size=None, - ) - - # Encoder ngram embedding tables - self.encoder_ngram_embedding = None - if config.encoder_enable_byte_ngrams: - self.encoder_ngram_embedding = nn.ModuleList() - self.encoder_ngram_to_size = parse_ngram_to_size( - config.encoder_ngram_to_size_str + + # NOTE: Frequency-based n-gram embeddings were experimental and removed in final model + # See paper section 3.2.1: "we subsequently moved to hash-based n-gram embeddings" + # The code below is kept for backward compatibility but should not be used + if hasattr(config, 'encoder_enable_byte_ngrams') and config.encoder_enable_byte_ngrams: + import warnings + warnings.warn( + "Frequency-based n-gram embeddings (encoder_enable_byte_ngrams) are deprecated. " + "The final BLT model uses only hash-based n-gram embeddings. " + "Consider setting encoder_enable_byte_ngrams=False.", + DeprecationWarning ) - ngram_emb_dim = self.local_encoder.dim - for ngram_vocab_size in self.encoder_ngram_to_size.values(): - self.encoder_ngram_embedding.append( - nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) - ) + + # Remove the duplicate/unused ngram embedding initialization + # self.encoder_ngram_embedding = init_embeddings(...) # Removed + # self.encoder_ngram_embedding = None # Removed + # if config.encoder_enable_byte_ngrams: ... # Removed # Output layer assert config.vocab_size > 0, "vocab_size must be greater than 0" @@ -2067,8 +2007,6 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - - def patch( self, tokens: torch.Tensor, @@ -2173,12 +2111,9 @@ def forward( self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None, - ngram_ids: Optional[torch.Tensor] = None, ): - # Ensure ngram_ids is either a tensor or None - assert ( - isinstance(ngram_ids, torch.Tensor) or ngram_ids is None - ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model bs, N = tokens.shape # Batch size and sequence length @@ -2239,23 +2174,8 @@ def forward( encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, ) - # N-gram table embeddings - if self.encoder_ngram_embedding is not None: - assert ngram_ids is not None, "ngram_ids must be provided" - if local_encoder_embeds is None: - local_encoder_embeds = self.local_encoder.tok_embeddings( - local_encoder_tokens - ) - assert len(ngram_ids) == len( - self.encoder_ngram_embedding - ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" - for i in range(ngram_ids.shape[0]): - ngram_embedding = self.encoder_ngram_embedding[i] - ngram_embeds = ngram_embedding(ngram_ids[i]) - assert ( - local_encoder_embeds.shape == ngram_embeds.shape - ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" - local_encoder_embeds = local_encoder_embeds + ngram_embeds + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings # Local encoder (h_encoder, h_cross), cache_encoder = self.local_encoder( @@ -2335,4 +2255,28 @@ def init_weights(self): std=emb_std, a=-3 * emb_std, b=3 * emb_std, - ) \ No newline at end of file + ) + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) From fb1d11ba15ce5e7dd69d6968de0f2a8f2b566966 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 11 Jun 2025 10:53:23 +0000 Subject: [PATCH 007/139] single config class in config file --- .../models/blt_wip/configuration_blt.py | 270 +++++++++------- .../models/blt_wip/modeling_blt_wip.py | 299 ++++++++---------- 2 files changed, 291 insertions(+), 278 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index d6fc85789329..645906a1f2fd 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -26,9 +26,7 @@ class InitStdFactor(str, Enum): DISABLED = "disabled" # Init std is divided by 1.0 - GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) - DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 class PatchingModeEnum(str, Enum): @@ -217,6 +215,52 @@ class BLTConfig(PretrainedConfig): The id of the "end-of-sequence" token. pad_token_id (`int`, *optional*, defaults to -1): The id of the padding token. + + # Patcher/Entropy model configuration + patcher_vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size for the entropy model used in patching. + patcher_dim (`int`, *optional*, defaults to 512): + Hidden dimension for the entropy model. + patcher_n_layers (`int`, *optional*, defaults to 8): + Number of layers in the entropy model. + patcher_n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the entropy model. + patcher_head_dim (`int`, *optional*): + Dimension of each attention head in the entropy model. + patcher_n_kv_heads (`int`, *optional*): + Number of key-value heads in the entropy model. + patcher_max_seqlen (`int`, *optional*, defaults to 1024): + Maximum sequence length for the entropy model. + patcher_norm_eps (`float`, *optional*, defaults to 1e-5): + Layer normalization epsilon for the entropy model. + patcher_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the entropy model. + patcher_sliding_window (`int`, *optional*): + Sliding window size for the entropy model attention. + patcher_ffn_dim_multiplier (`float`, *optional*): + Feedforward dimension multiplier for the entropy model. + patcher_multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this for the entropy model. + patcher_rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter for the entropy model. + patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product for the entropy model. + patcher_attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation for the entropy model. + patcher_attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type for the entropy model. + patcher_init_base_std (`float`, *optional*): + Base initialization standard deviation for the entropy model. + patcher_init_std_factor (`str`, *optional*, defaults to "disabled"): + Initialization std factor for the entropy model. + patcher_dim_token_emb (`int`, *optional*): + Token embedding dimension for the entropy model. + patcher_weight_tying (`bool`, *optional*, defaults to False): + Whether to tie embeddings in the entropy model. + patcher_bos_token_id (`int`, *optional*, defaults to 1): + Beginning of sequence token id for the entropy model. + patcher_eos_token_id (`int`, *optional*, defaults to 2): + End of sequence token id for the entropy model. ```python >>> from transformers import ByteLatentTransformer, BLTConfig @@ -335,6 +379,30 @@ def __init__( eos_token_id=2, pad_token_id=-1, + # Patcher/Entropy model configuration + patcher_vocab_size=256, + patcher_dim=512, + patcher_n_layers=8, + patcher_n_heads=8, + patcher_head_dim=None, + patcher_n_kv_heads=None, + patcher_max_seqlen=1024, + patcher_norm_eps=1e-5, + patcher_dropout=0.0, + patcher_sliding_window=None, + patcher_ffn_dim_multiplier=None, + patcher_multiple_of=256, + patcher_rope_theta=10000.0, + patcher_rope_use_fp32_in_outer_product=False, + patcher_attn_impl="sdpa", + patcher_attn_bias_type="causal", + patcher_init_base_std=None, + patcher_init_std_factor="disabled", + patcher_dim_token_emb=None, + patcher_weight_tying=False, + patcher_bos_token_id=1, + patcher_eos_token_id=2, + # Inherited **kwargs, ): @@ -432,6 +500,30 @@ def __init__( # Parameter mixing self.pm_size = pm_size + + # Patcher/Entropy model configuration + self.patcher_vocab_size = patcher_vocab_size + self.patcher_dim = patcher_dim + self.patcher_n_layers = patcher_n_layers + self.patcher_n_heads = patcher_n_heads + self.patcher_head_dim = patcher_head_dim + self.patcher_n_kv_heads = patcher_n_kv_heads + self.patcher_max_seqlen = patcher_max_seqlen + self.patcher_norm_eps = patcher_norm_eps + self.patcher_dropout = patcher_dropout + self.patcher_sliding_window = patcher_sliding_window + self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier + self.patcher_multiple_of = patcher_multiple_of + self.patcher_rope_theta = patcher_rope_theta + self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product + self.patcher_attn_impl = patcher_attn_impl + self.patcher_attn_bias_type = patcher_attn_bias_type + self.patcher_init_base_std = patcher_init_base_std + self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor) + self.patcher_dim_token_emb = patcher_dim_token_emb + self.patcher_weight_tying = patcher_weight_tying + self.patcher_bos_token_id = patcher_bos_token_id + self.patcher_eos_token_id = patcher_eos_token_id # Handle hash byte group size validation if ( @@ -451,113 +543,75 @@ def __init__( **kwargs, ) + @property + def encoder_dim_token_emb(self): + """Compute encoder token embedding dimension.""" + if self.dim_token is not None: + return self.dim_token + elif self.use_local_encoder_transformer: + return self.dim_local_encoder + else: + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return self.dim_global // patch_size -# Separate config for the LM Transformer (entropy model) -class LMTransformerConfig(PretrainedConfig): - r""" - Configuration class for the LM Transformer used as entropy model in BLT patching. - - Args: - vocab_size (`int`, *optional*, defaults to 256): - Vocabulary size of the LM model. - dim (`int`, *optional*, defaults to 512): - Dimension of the hidden representations. - n_layers (`int`, *optional*, defaults to 8): - Number of hidden layers. - n_heads (`int`, *optional*, defaults to 8): - Number of attention heads. - head_dim (`int`, *optional*): - Dimension of each attention head. - n_kv_heads (`int`, *optional*): - Number of key-value heads for grouped query attention. - max_seqlen (`int`, *optional*, defaults to 1024): - Maximum sequence length. - norm_eps (`float`, *optional*, defaults to 1e-5): - Epsilon for layer normalization. - dropout (`float`, *optional*, defaults to 0.0): - Dropout probability. - sliding_window (`int`, *optional*): - Sliding window size for attention. - ffn_dim_multiplier (`float`, *optional*): - Multiplier for feedforward dimension. - multiple_of (`int`, *optional*, defaults to 256): - Make feedforward dimension multiple of this. - rope_theta (`float`, *optional*, defaults to 10000.0): - RoPE theta parameter. - rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): - Whether to use fp32 in RoPE outer product. - attn_impl (`str`, *optional*, defaults to "sdpa"): - Attention implementation. - attn_bias_type (`str`, *optional*, defaults to "causal"): - Attention bias type. - init_base_std (`float`, *optional*): - Base initialization standard deviation. - init_std_factor (`str`, *optional*, defaults to "disabled"): - Initialization std factor. - dim_token_emb (`int`, *optional*): - Token embedding dimension. - weight_tying (`bool`, *optional*, defaults to False): - Whether to tie embeddings. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of sequence token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of sequence token id. - """ - - model_type = "lm_transformer" - - def __init__( - self, - vocab_size=256, - dim=512, - n_layers=8, - n_heads=8, - head_dim=None, - n_kv_heads=None, - max_seqlen=1024, - norm_eps=1e-5, - dropout=0.0, - sliding_window=None, - ffn_dim_multiplier=None, - multiple_of=256, - rope_theta=10000.0, - rope_use_fp32_in_outer_product=False, - attn_impl="sdpa", - attn_bias_type="causal", - init_base_std=None, - init_std_factor="disabled", - dim_token_emb=None, - weight_tying=False, - bos_token_id=1, - eos_token_id=2, - **kwargs, - ): - self.vocab_size = vocab_size - self.dim = dim - self.n_layers = n_layers - self.n_heads = n_heads - self.head_dim = head_dim - self.n_kv_heads = n_kv_heads - self.max_seqlen = max_seqlen - self.norm_eps = norm_eps - self.dropout = dropout - self.sliding_window = sliding_window - self.ffn_dim_multiplier = ffn_dim_multiplier - self.multiple_of = multiple_of - self.rope_theta = rope_theta - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product - self.attn_impl = attn_impl - self.attn_bias_type = attn_bias_type - self.init_base_std = init_base_std - self.init_std_factor = InitStdFactor(init_std_factor) - self.dim_token_emb = dim_token_emb - self.weight_tying = weight_tying + @property + def encoder_dim_patch_emb(self): + """Compute encoder patch embedding dimension.""" + if self.cross_attn_encoder: + if self.cross_attn_init_by_pooling: + return self.dim_local_encoder + else: + return self.dim_global + return None - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) + @property + def global_dim_patch_emb(self): + """Compute global patch embedding dimension.""" + dim_token_emb = self.encoder_dim_token_emb + if self.cross_attn_encoder: + cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 + return dim_token_emb * cross_attn_k + elif ( + self.downsampling_by_pooling is None + or not self.downsampling_by_pooling + or len(self.downsampling_by_pooling) == 0 + ): + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return dim_token_emb * patch_size + else: + return dim_token_emb * sum( + [ + pooling in self.downsampling_by_pooling + for pooling in ["avg", "min", "max"] + ] + ) + + @property + def decoder_dim_token_emb(self): + """Compute decoder token embedding dimension.""" + if self.share_encoder_decoder_emb: + return self.encoder_dim_token_emb + elif self.dim_token is not None: + return self.dim_token + else: + return self.dim_local_decoder + + def get_init_std_factor(self, depth: int) -> float: + """ + Calculate the initialization standard deviation scaling factor for a given layer depth. + + Args: + depth: Current layer depth (0-indexed) + + Returns: + Scaling factor to divide the base initialization std by + """ + if self.init_std_factor == InitStdFactor.CURRENT_DEPTH: + return (2 * (depth + 1)) ** 0.5 + else: # DISABLED + return 1.0 -__all__ = ["BLTConfig", "LMTransformerConfig", "InitStdFactor", "PatchingModeEnum"] \ No newline at end of file +__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 044ee0e7aed0..223be89ed871 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from enum import Enum, auto +from enum import Enum from typing import Any, List, Optional, Tuple, Union import torch @@ -8,14 +8,12 @@ from pydantic import model_validator from torch import nn from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention -from typing_extensions import Self import json import logging import torch import torch.nn import torch.nn as nn -from pydantic import ConfigDict from torch.nn import functional as F import abc @@ -23,8 +21,6 @@ import os from contextlib import nullcontext -from pydantic import BaseModel - SEP = " " BOS_ID: int = 1 EOS_ID: int = 2 @@ -41,7 +37,6 @@ from .configuration_blt import ( BLTConfig, - LMTransformerConfig, PatchingModeEnum, InitStdFactor, ) @@ -81,13 +76,6 @@ def create_causal_mask( f"Attention {attn_impl} with {sliding_window} sliding window not implemented" ) - -class InitStdFactor(str, Enum): - DISABLED = "disabled" # Init std is divided by 1.0 - GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) - CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) - DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 - def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), @@ -442,31 +430,41 @@ class TransformerLayer(nn.Module): def __init__(self, args): super().__init__() - assert (args.head_dim is not None) or ( - args.n_heads is not None + # Extract parameters from dictionary + dim = args['dim'] + n_heads = args['n_heads'] + head_dim = args['head_dim'] + n_kv_heads = args['n_kv_heads'] + rope_theta = args['rope_theta'] + multiple_of = args['multiple_of'] + ffn_dim_multiplier = args['ffn_dim_multiplier'] + norm_eps = args['norm_eps'] + + assert (head_dim is not None) or ( + n_heads is not None ), "Should specify at least head_dim or n_heads" - self.head_dim = args.head_dim or args.dim // args.n_heads - self.n_heads = args.n_heads or args.dim // args.head_dim - self.n_kv_heads = args.n_kv_heads or self.n_heads + self.head_dim = head_dim or dim // n_heads + self.n_heads = n_heads or dim // head_dim + self.n_kv_heads = n_kv_heads or self.n_heads - assert args.n_heads % self.n_kv_heads == 0 - assert args.dim % args.n_heads == 0 + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 self.attention = Attention( - dim=args.dim, + dim=dim, head_dim=self.head_dim, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, - rope_theta=args.rope_theta, + rope_theta=rope_theta, ) self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, ) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) def forward( self, @@ -514,52 +512,78 @@ class LMTransformer( license="other", license_name="fair-noncommercial-research-license", license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", - coders={ - LMTransformerConfig: ( - lambda x: {"config": x.to_dict()}, - lambda data: LMTransformerConfig(**data), - ) - }, ): - def __init__(self, config: LMTransformerConfig): + def __init__(self, config): super().__init__() - self.dim = config.dim - self.init_base_std = config.init_base_std - self.attn_impl = config.attn_impl - self.attn_bias_type = config.attn_bias_type - self.init_std_factor = config.init_std_factor - self.max_seqlen = config.max_seqlen + # Store config reference for later use + self.config = config + + # Extract patcher parameters from BLTConfig + self.dim = config.patcher_dim + self.init_base_std = config.patcher_init_base_std + self.attn_impl = config.patcher_attn_impl + self.attn_bias_type = config.patcher_attn_bias_type + self.init_std_factor = config.patcher_init_std_factor + self.max_seqlen = config.patcher_max_seqlen + n_layers = config.patcher_n_layers + n_heads = config.patcher_n_heads + head_dim = config.patcher_head_dim + rope_theta = config.patcher_rope_theta + rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product + norm_eps = config.patcher_norm_eps + vocab_size = config.patcher_vocab_size + weight_tying = config.patcher_weight_tying + sliding_window = config.patcher_sliding_window + eos_token_id = config.patcher_eos_token_id + self.rope_embeddings = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or config.dim // config.n_heads, - max_seqlen=config.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + theta=rope_theta, + head_dim=head_dim or self.dim // n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, ) # Handle both eos_id and eos_token_id for compatibility - self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + self.eos_id = eos_token_id + + # Extract additional parameters for TransformerLayer + n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) + multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) + ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) + + # Create a simple parameter dict for TransformerLayer + layer_params = { + 'dim': self.dim, + 'n_heads': n_heads, + 'head_dim': head_dim, + 'n_kv_heads': n_kv_heads, + 'rope_theta': rope_theta, + 'multiple_of': multiple_of, + 'ffn_dim_multiplier': ffn_dim_multiplier, + 'norm_eps': norm_eps, + } self.layers = nn.ModuleList() - for _ in range(config.n_layers): - self.layers.append(TransformerLayer(config)) + for _ in range(n_layers): + self.layers.append(TransformerLayer(layer_params)) # LMTransformer specific attributes - self.weight_tying = config.weight_tying - self.sliding_window = config.sliding_window + self.weight_tying = weight_tying + self.sliding_window = sliding_window - assert config.vocab_size > 0 + assert vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(config.vocab_size, config.dim) + self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.norm = RMSNorm(self.dim, eps=norm_eps) self.output = nn.Linear( - config.dim, - config.vocab_size, + self.dim, + vocab_size, bias=False, ) - if config.weight_tying: + if self.weight_tying: self.output.weight = self.tok_embeddings.weight def get_output_seq_len(self): @@ -623,13 +647,7 @@ def init_weights(self): self.rope_embeddings.reset_parameters() for depth, layer in enumerate(self.layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - + factor = self.config.get_init_std_factor(depth) layer.init_weights(self.init_base_std, factor) if not self.weight_tying: @@ -827,56 +845,6 @@ def split_large_numbers(lst, m): return new_lst -def get_encoder_dim_token_emb(args): - if args.dim_token is not None: - dim_token_emb = args.dim_token - elif args.use_local_encoder_transformer: - dim_token_emb = args.dim_local_encoder - else: - dim_token_emb = args.dim_global // args.patch_size - return dim_token_emb - - -def get_encoder_dim_patch_emb(args): - dim_patch_emb = None - if args.cross_attn_encoder: - if args.cross_attn_init_by_pooling: - dim_patch_emb = args.dim_local_encoder - else: - dim_patch_emb = args.dim_global - return dim_patch_emb - - -def get_global_dim_patch_emb(args): - dim_token_emb = get_encoder_dim_token_emb(args) - if args.cross_attn_encoder: - dim_patch_emb = dim_token_emb * args.cross_attn_k - elif ( - args.downsampling_by_pooling is None - or not args.downsampling_by_pooling - or len(args.downsampling_by_pooling) == 0 - ): - dim_patch_emb = dim_token_emb * args.patch_size - else: - dim_patch_emb = dim_token_emb * sum( - [ - pooling in args.downsampling_by_pooling - for pooling in ["avg", "min", "max"] - ] - ) - return dim_patch_emb - - -def get_decoder_dim_token_emb(args): - if args.share_encoder_decoder_emb: - dim_token_emb = get_encoder_dim_token_emb(args) - elif args.dim_token is not None: - dim_token_emb = args.dim_token - else: - dim_token_emb = args.dim_local_decoder - return dim_token_emb - - def fill_tokens(tokens, patch_size, fill_id): batch_size, seq_len = tokens.shape if seq_len % patch_size == 0: @@ -941,14 +909,6 @@ def byte_group_hash_function( """ with torch.no_grad(): bs, seq_len = x.shape - # x_numpy = x.numpy() - # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) - # for i in range(bs): - # for j in range(seq_len): - # start = max(j, j-group_size+1) - # end = j+1 - # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) - prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) x = torch.cat([prefix, x], dim=1) windows = x.unfold(1, group_size, 1) @@ -1201,18 +1161,20 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): # Initialize cross attention layers as None (will be set by subclasses if needed) self.cross_attn_layers = None - # Create component-specific config for TransformerLayers by copying config and overriding dimensions - component_config = type(config)(**config.to_dict()) - component_config.dim = self.dim - component_config.n_layers = self.n_layers - component_config.n_heads = self.n_heads - if hasattr(config, 'attn_bias_type'): - component_config.attn_bias_type = self.attn_bias_type - if hasattr(config, 'max_seqlen'): - component_config.max_seqlen = self.max_seqlen + # Create parameter dict for TransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': self.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } self.layers = nn.ModuleList( - [TransformerLayer(component_config) for _ in range(self.n_layers)] + [TransformerLayer(layer_params) for _ in range(self.n_layers)] ) if not self.use_rope: @@ -1228,10 +1190,10 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): # Set dimension-specific embedding dimensions if component_type == "encoder": - self.dim_token_emb = get_encoder_dim_token_emb(config) - self.dim_patch_emb = get_encoder_dim_patch_emb(config) + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb elif component_type == "decoder": - self.dim_token_emb = get_decoder_dim_token_emb(config) + self.dim_token_emb = config.decoder_dim_token_emb self.dim_patch_emb = config.dim_global self.token_embedding_projection = ( @@ -1296,14 +1258,8 @@ def init_weights(self, init_std=None): ) for depth, layer in enumerate(self.layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - - layer.init_weights(None, factor) + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) if hasattr(self, "output"): nn.init.trunc_normal_( @@ -1335,13 +1291,7 @@ def init_weights(self, init_std=None): if self.cross_attn_layers is not None: for depth, layer in enumerate(self.cross_attn_layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - + factor = self.config.get_init_std_factor(depth) layer.init_weights(None, factor) @@ -1708,9 +1658,21 @@ def __init__(self, config): # Handle both eos_id and eos_token_id for compatibility self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + # Create parameter dict for TransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': config.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } + self.layers = nn.ModuleList() for _ in range(config.n_layers): - self.layers.append(TransformerLayer(config)) + self.layers.append(TransformerLayer(layer_params)) # GlobalTransformer specific attributes self.dropout = config.dropout @@ -1766,13 +1728,7 @@ def forward( def init_weights(self): self.rope_embeddings.reset_parameters() for depth, layer in enumerate(self.layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - + factor = self.config.get_init_std_factor(depth) layer.init_weights(self.init_base_std, factor) # GlobalTransformer specific initialization @@ -1904,7 +1860,7 @@ def __init__(self, config: BLTConfig): global_config.n_layers = config.n_layers_global global_config.n_heads = config.n_heads_global global_config.n_kv_heads = config.n_kv_heads_global - global_config.dim_token_emb = get_global_dim_patch_emb(config) + global_config.dim_token_emb = config.global_dim_patch_emb self.global_transformer = GlobalTransformer(global_config) @@ -1968,18 +1924,21 @@ def __init__(self, config: BLTConfig): logger.warning( "Update checkpoint to load attn and sliding window args from checkpoint" ) - entropy_model_config = LMTransformerConfig( - dim=model_params["dim"], - n_layers=model_params["n_layers"], - n_heads=model_params["n_heads"], - max_seqlen=model_params["max_seqlen"], - ffn_dim_multiplier=model_params["ffn_dim_multiplier"], - vocab_size=model_params["vocab_size"], - attn_bias_type="local_block_causal", - attn_impl="sdpa", # originally xformers - sliding_window=512, - ) - self.patcher = LMTransformer(entropy_model_config) + + # Override patcher configuration with actual entropy model parameters from checkpoint + config.patcher_dim = model_params["dim"] + config.patcher_n_layers = model_params["n_layers"] + config.patcher_n_heads = model_params["n_heads"] + config.patcher_max_seqlen = model_params["max_seqlen"] + config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"] + config.patcher_vocab_size = model_params["vocab_size"] + # Use sensible defaults for parameters not in checkpoint + config.patcher_attn_bias_type = "local_block_causal" + config.patcher_attn_impl = "sdpa" # originally xformers + config.patcher_sliding_window = 512 + + # LMTransformer will extract patcher_ parameters from config directly + self.patcher = LMTransformer(config) # Load state dict maybe_consolidated = os.path.join( From 1eab6a4ae13690d9523b6240eab2b3586e479da2 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 11 Jun 2025 15:43:02 +0000 Subject: [PATCH 008/139] inherit from PreTrainedModel --- src/demo_hf.py | 37 ++- .../models/blt_wip/modeling_blt_wip.py | 269 +++++++----------- 2 files changed, 127 insertions(+), 179 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 013b5217a69e..e6b566c43668 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -6,7 +6,7 @@ import torch -from transformers.models.blt_wip.modeling_blt_wip import ByteLatentTransformer +from transformers.models.blt_wip.modeling_blt_wip import BLTModel from transformers.models.blt_wip.configuration_blt import BLTConfig from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer @@ -47,7 +47,7 @@ def sample_top_p(probs, p): def generate( prompts: list[str] | None, *, - model: ByteLatentTransformer, + model: BLTModel, tokenizer: BltTokenizer, max_prompt_len: int = 256, max_gen_len: int = 256, @@ -114,8 +114,8 @@ def generate( -def main(prompt: str = "hi", model_name: str = "blt-1b"): - device = "mps" +def main(prompt: str = "my name is", model_name: str = "blt-1b"): + device = "cuda" #HF blt_repo = "facebook/blt-1b" @@ -134,24 +134,35 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"): config['args']['attn_bias_type'] = 'causal' config['args']['attn_impl'] = 'sdpa' - - model_config = BLTConfig(**config["args"]) + # Create model using normal constructor instead of from_pretrained + from transformers.models.blt_wip.configuration_blt import BLTConfig + model_config = BLTConfig(**config['args']) + + # Set entropy model parameters manually patcher_args = entropy_params["data"]["patcher_args"] model_config.patch_in_forward = True - model_config.realtime_patching = True # Enable realtime patching + model_config.realtime_patching = True model_config.patch_size = patcher_args["patch_size"] - model_config.patching_mode = "entropy" #patcher_args["patching_mode"] #TODO: we need to pass "entropy" to run through the Patcher / "entropy model", which is the LMTransformer + model_config.patching_mode = "entropy" model_config.patching_threshold = patcher_args["threshold"] model_config.patching_threshold_add = patcher_args["threshold_add"] model_config.max_patch_length = patcher_args["max_patch_length"] model_config.patching_batch_size = patcher_args["patching_batch_size"] model_config.patching_device = patcher_args["patching_device"] model_config.monotonicity = patcher_args["monotonicity"] - model_config.entropy_model_checkpoint_dir = entropy_dir #original config on the hub don't set this - - - model = ByteLatentTransformer.from_pretrained(blt_repo, config=model_config).to(device) + model_config.entropy_model_checkpoint_dir = entropy_dir + + # Use direct construction instead of from_pretrained to avoid meta tensor issues + print("Creating model...") + model = BLTModel(model_config).to(device) + + # Load model weights manually + print("Loading model weights...") + from safetensors.torch import load_file + checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="model.safetensors") + state_dict = load_file(checkpoint_path) + model.load_state_dict(state_dict, strict=False) tokenizer = BltTokenizer( vocab_size_unit_1=model_config.vocab_size, @@ -164,7 +175,7 @@ def main(prompt: str = "hi", model_name: str = "blt-1b"): prompts, model=model, tokenizer=tokenizer, - max_gen_len=4, + max_gen_len=200, device=device ) diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 223be89ed871..6020bdfe898b 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -16,8 +16,6 @@ import torch.nn as nn from torch.nn import functional as F -import abc - import os from contextlib import nullcontext @@ -41,6 +39,9 @@ InitStdFactor, ) +from ...modeling_utils import PreTrainedModel +from ...utils import logging as transformers_logging + flex_attention_comp = flex_attention @@ -495,16 +496,12 @@ def init_weights(self, init_std=None, factor=1.0): self.ffn_norm.reset_parameters() -class SequenceModelWithOutput(abc.ABC): - @abc.abstractmethod - def get_output_seq_len(self) -> int: - pass + class LMTransformer( nn.Module, - SequenceModelWithOutput, PyTorchModelHubMixin, repo_url="https://github.com/facebookresearch/blt", # paper_url="https://arxiv.org/abs/2412.09871", @@ -585,9 +582,6 @@ def __init__(self, config): if self.weight_tying: self.output.weight = self.tok_embeddings.weight - - def get_output_seq_len(self): - return self.max_seqlen def push_to_hub(self, *args, **kwargs): raise ValueError( @@ -1125,8 +1119,11 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class LocalModelBase(nn.Module): def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() + + # Store config for later use + self.config = config - # Use component-specific dimensions + # Use component-specific dimensions if component_type == "encoder": self.dim = config.dim_local_encoder self.n_layers = config.n_layers_local_encoder @@ -1151,6 +1148,7 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): self.attn_impl = config.attn_impl self.use_rope = config.use_rope self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) self.cross_attn_k = getattr(config, "cross_attn_k", None) @@ -1639,10 +1637,13 @@ def init_weights(self, base_std: float, factor: float = 1.0): self.cross_attn_norm_kv.reset_parameters() -class GlobalTransformer(nn.Module, SequenceModelWithOutput): +class GlobalTransformer(nn.Module): def __init__(self, config): super().__init__() + # Store config for later use + self.config = config + self.dim = config.dim self.init_base_std = config.init_base_std self.attn_impl = config.attn_impl @@ -1685,9 +1686,6 @@ def __init__(self, config): config.dim, bias=False, ) - - def get_output_seq_len(self): - return self.max_seqlen def forward( self, @@ -1789,69 +1787,37 @@ def compute_hash_embeddings( return local_encoder_embeds -class ByteLatentTransformer( - nn.Module, - SequenceModelWithOutput, - PyTorchModelHubMixin, - repo_url="https://github.com/facebookresearch/blt", - # paper_url="https://arxiv.org/abs/2412.09871", - pipeline_tag="text-generation", - license="other", - license_name="fair-noncommercial-research-license", - license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", - coders={ - BLTConfig: ( - lambda x: {"config": x.to_dict()}, - lambda data: BLTConfig(**data), - ) - }, -): +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init""" + # Don't do anything here - we use the custom init_weights method instead + pass + + +class BLTModel(BLTPreTrainedModel): """ - The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences - by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, - and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoder/decoder and aglobal transformer + to efficiently encode and decode byte sequences, leveraging patch-based processing for improved performance and inference efficiency. """ def __init__(self, config: BLTConfig): - super().__init__() + super().__init__(config) # Store config reference self.config = config - # General configuration - self.weight_tying = config.weight_tying - self.patch_size = config.patch_size - self.patching_mode = config.patching_mode - self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( - BOE_ID, - BOS_ID, - PAD_ID, - config.eos_token_id, - ) - self.downsampling_by_pooling = config.downsampling_by_pooling - self.patching_threshold = config.patching_threshold - self.dim = config.dim - self.init_base_std = config.init_base_std - self.init_std_factor = config.init_std_factor - self.max_seqlen = config.max_seqlen - - # Cross attention configuration - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_k = config.cross_attn_k - self.cross_attn_window_encoder = config.cross_attn_window_encoder - self.cross_attn_window_decoder = config.cross_attn_window_decoder - self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention - - # Encoder hash configuration - self.encoder_hash_byte_group_size = config.encoder_hash_byte_group_size - self.encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - self.encoder_hash_byte_group_nb_functions = ( - config.encoder_hash_byte_group_nb_functions - ) - - # ByteLatent modules - pass the full config directly to local models + # Create main components - they will read their parameters from config self.local_encoder = LocalEncoder(config) # Create global-specific config by copying config and overriding dimensions @@ -1863,47 +1829,17 @@ def __init__(self, config: BLTConfig): global_config.dim_token_emb = config.global_dim_patch_emb self.global_transformer = GlobalTransformer(global_config) - self.local_decoder = LocalDecoder(config) + + # Initialize hash embeddings self.encoder_hash_tok_embedding = init_hash_embeddings( config, local_encoder_dim=self.local_encoder.dim, - encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - - # NOTE: Frequency-based n-gram embeddings were experimental and removed in final model - # See paper section 3.2.1: "we subsequently moved to hash-based n-gram embeddings" - # The code below is kept for backward compatibility but should not be used - if hasattr(config, 'encoder_enable_byte_ngrams') and config.encoder_enable_byte_ngrams: - import warnings - warnings.warn( - "Frequency-based n-gram embeddings (encoder_enable_byte_ngrams) are deprecated. " - "The final BLT model uses only hash-based n-gram embeddings. " - "Consider setting encoder_enable_byte_ngrams=False.", - DeprecationWarning - ) - - # Remove the duplicate/unused ngram embedding initialization - # self.encoder_ngram_embedding = init_embeddings(...) # Removed - # self.encoder_ngram_embedding = None # Removed - # if config.encoder_enable_byte_ngrams: ... # Removed - - # Output layer - assert config.vocab_size > 0, "vocab_size must be greater than 0" - # Patcher configuration - self.patch_in_forward = config.patch_in_forward + # Initialize patcher if needed if config.patch_in_forward: - # Store patching parameters - self.patching_mode = config.patching_mode - self.patching_threshold = config.patching_threshold - self.patching_threshold_add = config.patching_threshold_add - self.monotonicity = config.monotonicity - self.max_patch_length = config.max_patch_length - self.patching_batch_size = config.patching_batch_size or 1 - self.patching_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #config.patching_device or "cuda" - - # Initialize entropy model (patcher) if realtime_patching is True if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: # Load entropy model directly entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir @@ -1939,32 +1875,30 @@ def __init__(self, config: BLTConfig): # LMTransformer will extract patcher_ parameters from config directly self.patcher = LMTransformer(config) - - # Load state dict - maybe_consolidated = os.path.join( - entropy_model_checkpoint_dir, - "consolidated/consolidated.pth", - ) - if os.path.exists(maybe_consolidated): - state_path = maybe_consolidated - else: - state_path = os.path.join( - entropy_model_checkpoint_dir, "consolidated.pth" - ) - if not os.path.exists(state_path): - raise FileNotFoundError(f"Model checkpoint not found at: {state_path}") + state_path = os.path.join( + entropy_model_checkpoint_dir, "consolidated.pth" + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.patcher.load_state_dict( - torch.load(state_path, map_location=self.patching_device)["model"], strict=False + torch.load(state_path, map_location=device)["model"], strict=False ) - self.patcher.to(self.patching_device) + self.patcher.to(device) self.patcher = self.patcher.eval() # no grads for the model: for param in self.patcher.parameters(): param.requires_grad = False else: self.patcher = None + + # Initialize weights and apply final processing + self.post_init() + + @property + def patch_in_forward(self): + """Backward compatibility property for accessing patch_in_forward from config.""" + return self.config.patch_in_forward def patch( self, @@ -2001,11 +1935,11 @@ def patch( seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode == PatchingModeEnum.byte: + if self.config.patching_mode == PatchingModeEnum.byte: patch_lengths = torch.ones( (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device ) - elif self.patching_mode == PatchingModeEnum.entropy: + elif self.config.patching_mode == PatchingModeEnum.entropy: if entropies is not None: scores = entropies.to(dtype=torch.float32) elif preds is not None: @@ -2014,28 +1948,28 @@ def patch( scores, _ = calculate_entropies( tokens, self.patcher, - self.patching_batch_size, - self.patching_device, + self.config.patching_batch_size, + self.config.patching_device, ) patch_start_ids = find_entropy_patch_start_ids( scores, - self.patch_size, + self.config.patch_size, include_next_token=include_next_token, - threshold=threshold if threshold is not None else self.patching_threshold, - threshold_add=self.patching_threshold_add, - monotonicity=self.monotonicity, + threshold=threshold if threshold is not None else self.config.patching_threshold, + threshold_add=self.config.patching_threshold_add, + monotonicity=self.config.monotonicity, ) patch_lengths = patch_lengths_from_start_ids( patch_start_ids, seq_len_next_tok ) else: - raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + raise NotImplementedError(f"self.config.patching_mode {self.config.patching_mode}") # Apply any processing to patch lengths - if self.max_patch_length is not None: + if self.config.max_patch_length is not None: # TODO: avoid going back to a list here. patch_lengths = [ - split_large_numbers(pl, self.max_patch_length) + split_large_numbers(pl, self.config.max_patch_length) for pl in patch_lengths.tolist() ] max_len = max([len(pl) for pl in patch_lengths]) @@ -2058,14 +1992,6 @@ def patch( ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" return patch_lengths, scores - def push_to_hub(self, *args, **kwargs): - raise ValueError( - "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." - ) - - def get_output_seq_len(self): - return self.max_seqlen - def forward( self, tokens: torch.Tensor, @@ -2077,24 +2003,24 @@ def forward( bs, N = tokens.shape # Batch size and sequence length # Get megabyte inputs - nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) + nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) local_encoder_tokens, _, local_decoder_tokens = get_blt_input( tokens=tokens, enforce_patch_size_multiple=False, nb_boe=nb_boe, - patch_size=self.patch_size, - boe_id=self.boe_id, + patch_size=self.config.patch_size, + boe_id=BOE_ID, ) # Patching if patch_lengths is None: assert ( - getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward + getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward ), "Patch in forward not enabled and no patch_lengths passed." patch_lengths, tok_scores = self.patch( local_encoder_tokens, include_next_token=True, - threshold=self.patching_threshold, + threshold=self.config.patching_threshold, ) else: if nb_boe > 0: @@ -2112,15 +2038,15 @@ def forward( cross_attn_mask_enc = None # Cross-attention encoder - if self.cross_attn_encoder: + if self.config.cross_attn_encoder: cross_attn_mask_enc = cross_attn_mask( patch_ids, patch_lengths, N, patches_as_queries=True, - cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_encoder, - block_mask=self.cross_attn_use_flex_attention, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, ) # Hashing and embedding @@ -2128,9 +2054,9 @@ def forward( local_encoder_tokens=local_encoder_tokens, local_encoder=self.local_encoder, encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, - encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, - encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, - encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, ) # NOTE: Frequency-based n-gram embeddings removed as per paper @@ -2150,10 +2076,10 @@ def forward( h = h_cross.view(bs, patch_lengths.shape[1], -1) # Global transformer - global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.eos_id) + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.eos_id + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id h, _ = self.global_transformer( embeds=h, @@ -2175,7 +2101,7 @@ def forward( ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" # Cross-attention decoder - if not self.cross_attn_decoder: + if not self.config.cross_attn_decoder: h = torch.gather( h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) ) @@ -2187,9 +2113,9 @@ def forward( patch_lengths, N, patches_as_queries=False, - cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_decoder, - block_mask=self.cross_attn_use_flex_attention, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, ) # Local decoder @@ -2206,15 +2132,16 @@ def init_weights(self): self.global_transformer.init_weights() self.local_decoder.init_weights() - emb_std = self.local_encoder.dim ** (-0.5) - for emb in self.encoder_hash_tok_embedding: - nn.init.trunc_normal_( - emb.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) + if self.encoder_hash_tok_embedding is not None: + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, + ) def init_hash_embeddings( config, @@ -2239,3 +2166,13 @@ def init_hash_embeddings( ) return nn.ModuleList(embeddings) + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "LMTransformer", + "LocalEncoder", + "LocalDecoder", + "GlobalTransformer", +] From bc2aeb746082082bcc34cac786adedc15ca3f1c0 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 12 Jun 2025 13:01:08 +0000 Subject: [PATCH 009/139] refactor LMTransformer --> BLTPatcher --- src/demo_hf.py | 6 +- .../models/blt_wip/modeling_blt_wip.py | 976 +++++++++--------- 2 files changed, 483 insertions(+), 499 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index e6b566c43668..f8db4eaf117d 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -60,9 +60,6 @@ def generate( ) -> list[list[int]]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - assert ( - model.patch_in_forward - ), "generate requires model.patch_in_forward=True" model.eval() prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts] # Truncation @@ -81,8 +78,7 @@ def generate( for i, curr_pos in enumerate(range(start_pos, end_pos)): current_tokens = tokens[:, :curr_pos] - patch_lengths, _ = model.patch(current_tokens, include_next_token=True) - logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1] + logits = model(current_tokens)[:, -1] if use_sampling: probs = torch.softmax(logits / temp, dim=-1) diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 6020bdfe898b..b8eedb21498d 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union import torch -from huggingface_hub import PyTorchModelHubMixin from pydantic import model_validator from torch import nn from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention @@ -235,7 +234,7 @@ def forward( return self.freqs_cis[0:seqlen] -class Attention(nn.Module): +class BLTAttention(nn.Module): def __init__( self, dim: int, @@ -357,7 +356,7 @@ def reset_parameters(self, init_std=None, factor=1.0): ) -class FeedForward(nn.Module): +class BLTMLP(nn.Module): def __init__( self, dim: int, @@ -427,7 +426,7 @@ def reset_parameters(self, init_std=None, factor=1.0): ) -class TransformerLayer(nn.Module): +class BLTTransformerLayer(nn.Module): def __init__(self, args): super().__init__() @@ -451,14 +450,14 @@ def __init__(self, args): assert n_heads % self.n_kv_heads == 0 assert dim % n_heads == 0 - self.attention = Attention( + self.attention = BLTAttention( dim=dim, head_dim=self.head_dim, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, rope_theta=rope_theta, ) - self.feed_forward = FeedForward( + self.feed_forward = BLTMLP( dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, @@ -496,164 +495,6 @@ def init_weights(self, init_std=None, factor=1.0): self.ffn_norm.reset_parameters() - - - - -class LMTransformer( - nn.Module, - PyTorchModelHubMixin, - repo_url="https://github.com/facebookresearch/blt", - # paper_url="https://arxiv.org/abs/2412.09871", - pipeline_tag="text-generation", - license="other", - license_name="fair-noncommercial-research-license", - license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", -): - def __init__(self, config): - super().__init__() - - # Store config reference for later use - self.config = config - - # Extract patcher parameters from BLTConfig - self.dim = config.patcher_dim - self.init_base_std = config.patcher_init_base_std - self.attn_impl = config.patcher_attn_impl - self.attn_bias_type = config.patcher_attn_bias_type - self.init_std_factor = config.patcher_init_std_factor - self.max_seqlen = config.patcher_max_seqlen - n_layers = config.patcher_n_layers - n_heads = config.patcher_n_heads - head_dim = config.patcher_head_dim - rope_theta = config.patcher_rope_theta - rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product - norm_eps = config.patcher_norm_eps - vocab_size = config.patcher_vocab_size - weight_tying = config.patcher_weight_tying - sliding_window = config.patcher_sliding_window - eos_token_id = config.patcher_eos_token_id - - self.rope_embeddings = RotaryEmbedding( - theta=rope_theta, - head_dim=head_dim or self.dim // n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, - ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = eos_token_id - - # Extract additional parameters for TransformerLayer - n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) - multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) - ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) - - # Create a simple parameter dict for TransformerLayer - layer_params = { - 'dim': self.dim, - 'n_heads': n_heads, - 'head_dim': head_dim, - 'n_kv_heads': n_kv_heads, - 'rope_theta': rope_theta, - 'multiple_of': multiple_of, - 'ffn_dim_multiplier': ffn_dim_multiplier, - 'norm_eps': norm_eps, - } - - self.layers = nn.ModuleList() - for _ in range(n_layers): - self.layers.append(TransformerLayer(layer_params)) - - # LMTransformer specific attributes - self.weight_tying = weight_tying - self.sliding_window = sliding_window - - assert vocab_size > 0 - - self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) - - self.norm = RMSNorm(self.dim, eps=norm_eps) - - self.output = nn.Linear( - self.dim, - vocab_size, - bias=False, - ) - - if self.weight_tying: - self.output.weight = self.tok_embeddings.weight - - def push_to_hub(self, *args, **kwargs): - raise ValueError( - "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." - ) - - def forward( - self, - token_values: torch.Tensor, - target: Optional[torch.Tensor] = None, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - attn_impl: str | None = None, - ): - if attn_impl is None: - attn_impl = self.attn_impl - bsz, seqlen = token_values.shape - - h = self.tok_embeddings(token_values) - mask = ( - mask - if mask is not None - else create_causal_mask( - seqlen, - attn_impl, - self.attn_bias_type, - sliding_window=self.sliding_window, - tokens=token_values, - eos_id=self.eos_id, - ) - ) - - freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) - - for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) - - logits = self.output(self.norm(h)) - if target is not None: - return cross_entropy(logits, target) - else: - return logits - - def reset_parameters(self, init_std=None): - self.norm.reset_parameters() - - def init_weights(self): - self.reset_parameters() - init_std = self.dim ** (-0.5) - nn.init.trunc_normal_( - self.tok_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - self.rope_embeddings.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - - if not self.weight_tying: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - def rightpad(seq, pad_id, max_len): return seq + [pad_id] * (max_len - len(seq)) @@ -670,174 +511,6 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() -def entropy(scores): - """ - scores: [bs, seq_len, vocab] - returns [bs, seq_len] - - Computes the entropy for each token in the batch. - Note: uses natural log. - """ - log_probs = F.log_softmax(scores, dim=-1) - probs = torch.exp(log_probs) - p_log_p = log_probs * probs - entropy = -p_log_p.sum(dim=-1) - return entropy - - -def calculate_entropies( - tokens: torch.tensor, - entropy_model, - patching_batch_size, - device: str | None = None, - enable_grad: bool = False, -): - """ - tokens: 2D tensor of shape [batch_size, seq_len] - Return 2D tensor of shape [batch_size, seq_len] with entropies for each token. - - Splits the tokens into chunks of size max_length and calculates entropies for each chunk. - Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument. - """ - - grad_context = nullcontext() if enable_grad else torch.no_grad() - - with grad_context: - entropies = [] - preds = [] - max_length = getattr(entropy_model, "max_length", 8192) - batch_numel = max_length * patching_batch_size - splits = torch.split(tokens.flatten(), batch_numel) - for split in splits: - pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros( - pad_size, dtype=split.dtype, device=split.device, requires_grad=False - ) - split = torch.cat((split, pad), dim=0) - split = split.reshape(-1, max_length) - if device is not None: - split = split.to(device) - # assert torch.all(split >= 0) and torch.all(split < 260) - pred = entropy_model(split) - pred = pred.reshape(-1, pred.shape[-1])[ - : split.numel() - pad_size, : - ] # [batch_size * seq_len, vocab] - preds.append(pred) - pred_entropies = entropy(pred) - entropies.append(pred_entropies) - - concat_entropies = torch.cat(entropies, dim=0) - concat_entropies = concat_entropies.reshape(tokens.shape) - concat_preds = torch.cat(preds, dim=0) - concat_preds = concat_preds.reshape(tokens.shape[0], -1) - return concat_entropies, concat_preds - - -def patch_start_ids_from_patch_start_mask(patch_start_mask): - bs, trunc_seq_len = patch_start_mask.shape - max_patches = patch_start_mask.sum(dim=1).max() - if max_patches == 0: - patch_start_ids = torch.full( - (bs, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - else: - patch_ids = ( - torch.arange(trunc_seq_len, device=patch_start_mask.device) - .unsqueeze(0) - .repeat(bs, 1) - ) - extra_patch_ids = torch.full( - (bs, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) - patch_start_mask_padded = torch.cat( - (patch_start_mask, ~patch_start_mask), dim=1 - ) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( - bs, trunc_seq_len - )[:, :max_patches] - return patch_start_ids - - -def patch_lengths_from_start_ids(patch_start_ids, seq_len): - """ - Calculate patch lengths from start ids. - start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then - the rest are filled to the seq len. - seq_len: ex: 7 length of the sequence - - returns the patch lengths: - [1, 6] for the above example. - """ - last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) - patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) - patch_lengths = patch_end_ids - patch_start_ids + 1 - assert torch.all(patch_lengths >= 0), f"{patch_lengths}" - assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" - return patch_lengths - -def find_entropy_patch_start_ids( - entropies, - patch_size=None, - threshold=None, - threshold_add=None, - monotonicity=False, - include_next_token=True, -): - """ - Use entropies to find the start ids of each patch. - Use patch_size or threshold to figure out the total number of patches to allocate. - - When threshold is not None the number of patches is not constant between - different sequences, but patches can be identified incrementally rather than - decided globally using the entire sequence. - """ - bs, seq_len = entropies.shape[:2] - - first_ids = ( - torch.tensor([0, 1], dtype=torch.long, device=entropies.device) - .unsqueeze(0) - .repeat(bs, 1) - ) - preds_truncation_len = first_ids.shape[ - 1 - ] # remove the first preds because they will be start of patches. - entropies = entropies[:, 1:] - if threshold is None: - num_patches = seq_len // patch_size - patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices - patch_start_ids = patch_start_ids.sort(dim=1).values - else: - patch_start_mask = entropies > threshold - if not include_next_token: - patch_start_mask = patch_start_mask[:, :-1] - # patch_start_mask[1:] |= tokens[:-1] < OFFSET - patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask) - - patch_start_ids = torch.cat( - (first_ids, patch_start_ids + preds_truncation_len), dim=1 - ) - return patch_start_ids - -def split_large_numbers(lst, m): - new_lst = [] - for i in lst: - if i > m: - while i > m: - new_lst.append(m) - i -= m - new_lst.append(i) - else: - new_lst.append(i) - assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" - return new_lst - def fill_tokens(tokens, patch_size, fill_id): batch_size, seq_len = tokens.shape @@ -849,29 +522,6 @@ def fill_tokens(tokens, patch_size, fill_id): return torch.cat((tokens, final_padding), dim=1) -def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): - first_patch_length = patch_lengths[0, 0] - assert torch.all( - first_patch_length == patch_lengths[:, 0] - ), "first patch should always be the same size (1 for dynamic, patch_size for static)." - assert ( - first_patch_length - nb_boe == 1 - ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" - # Remove first patch from patch_ids for local decoder inputs and shift the last patch. - # decoder_patch_lengths = patch_lengths[:, 1:].clone() - # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) - decoder_patch_lengths = patch_lengths[:, 1:] - assert ( - decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] - == patch_lengths.sum() - ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" - assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" - decoder_patch_ids = patch_ids_from_lengths( - patch_lengths=decoder_patch_lengths, seq_len=seq_len - ) - return decoder_patch_ids - - def rolling_polynomial_hash(t, hash_func_nb: int = 0): primes = [ 1000000007, @@ -1097,25 +747,6 @@ def get_blt_input( return local_encoder_tokens, None, local_decoder_tokens -def patch_ids_from_lengths(patch_lengths, seq_len): - bs, num_patches = patch_lengths.shape - # Create a tensor of cumulative sums of the patch lengths - cum_d = torch.cat( - [ - torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1), - ], - dim=-1, - ) - patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( - dim=-2 - ) - 1 - assert not ( - torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 - ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" - return patch_ids - - class LocalModelBase(nn.Module): def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() @@ -1159,7 +790,7 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): # Initialize cross attention layers as None (will be set by subclasses if needed) self.cross_attn_layers = None - # Create parameter dict for TransformerLayers + # Create parameter dict for BLTTransformerLayers layer_params = { 'dim': self.dim, 'n_heads': self.n_heads, @@ -1172,7 +803,7 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): } self.layers = nn.ModuleList( - [TransformerLayer(layer_params) for _ in range(self.n_layers)] + [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)] ) if not self.use_rope: @@ -1312,7 +943,7 @@ def __init__(self, config: BLTConfig): layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( - CrossAttention( + BLTCrossAttention( dim=self.dim, head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, @@ -1433,7 +1064,7 @@ def __init__(self, config: BLTConfig): layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( - CrossAttention( + BLTCrossAttention( dim=self.dim, head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, @@ -1507,9 +1138,9 @@ def forward( return h_preds, cache -class CrossAttention(nn.Module): +class BLTCrossAttention(nn.Module): """ - CrossAttention block to attend to the encoder states from the decoder. + BLTCrossAttention block to attend to the encoder states from the decoder. Rope is not supported. """ @@ -1659,7 +1290,7 @@ def __init__(self, config): # Handle both eos_id and eos_token_id for compatibility self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) - # Create parameter dict for TransformerLayers + # Create parameter dict for BLTTransformerLayers layer_params = { 'dim': self.dim, 'n_heads': config.n_heads, @@ -1673,7 +1304,7 @@ def __init__(self, config): self.layers = nn.ModuleList() for _ in range(config.n_layers): - self.layers.append(TransformerLayer(layer_params)) + self.layers.append(BLTTransformerLayer(layer_params)) # GlobalTransformer specific attributes self.dropout = config.dropout @@ -1740,9 +1371,6 @@ def init_weights(self): b=3 * std, ) - - - def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, @@ -1788,10 +1416,23 @@ def compute_hash_embeddings( class BLTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + BLT models. + + This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. + It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. + + Args: + config ([`BLTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["TransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] + _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = False # BLT uses its own attention implementation _supports_sdpa = True @@ -1806,8 +1447,8 @@ def _init_weights(self, module): class BLTModel(BLTPreTrainedModel): """ The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences - by dynamically segmenting them into patches. It uses a combination of local encoder/decoder and aglobal transformer - to efficiently encode and decode byte sequences, leveraging patch-based processing for + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for improved performance and inference efficiency. """ @@ -1873,8 +1514,8 @@ def __init__(self, config: BLTConfig): config.patcher_attn_impl = "sdpa" # originally xformers config.patcher_sliding_window = 512 - # LMTransformer will extract patcher_ parameters from config directly - self.patcher = LMTransformer(config) + # BLTPatcher will extract patcher_ parameters from config directly + self.patcher = BLTPatcher(config) state_path = os.path.join( entropy_model_checkpoint_dir, "consolidated.pth" @@ -1895,102 +1536,68 @@ def __init__(self, config: BLTConfig): # Initialize weights and apply final processing self.post_init() - @property - def patch_in_forward(self): - """Backward compatibility property for accessing patch_in_forward from config.""" - return self.config.patch_in_forward + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total + ], dim=-1) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids - def patch( - self, - tokens: torch.Tensor, - include_next_token: bool = False, - preds: torch.Tensor | None = None, - entropies: torch.Tensor | None = None, - threshold: float = None, - ) -> torch.Tensor: + def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: """ - tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched - Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) - -> output tensor: [batch_size, max_num_patches] - each tensor is processed independently and gets right padded with zeros. - - Patching with the following modes: - 1. patching_mode = None: static patch size - 2. patching_mode = "entropy": - calculate entropy of each token, allocate patches so that the total - number of patches is the same as static patching but choose to begin - patches on tokens where the model is most uncertain (highest entropy). - - When threshold is provided, it uses the threshold to decide when to - start a new patch. - 3. patching_mode = "space": - use space like tokens to define the patches. - 4. patching_mode = "bpe": - use bpe delim tokens to define the patches. - - To correctly patch the last token, it may be necessary to include the next token in the patch - lengths calculations. This is controlled by the include_next_token argument. + Create decoder patch IDs by skipping the first encoder patch. + + The decoder starts after the first patch (which contains BOE tokens), + so we need to map decoder positions to the remaining patches. + + Args: + patch_lengths: [batch_size, num_patches] from encoder + nb_boe: number of beginning-of-example tokens in first patch + seq_len: decoder sequence length + + Returns: + decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices """ - bs, seq_len = tokens.shape - seq_len_next_tok = seq_len + 1 if include_next_token else seq_len - scores = None - # STATIC - if self.config.patching_mode == PatchingModeEnum.byte: - patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device - ) - elif self.config.patching_mode == PatchingModeEnum.entropy: - if entropies is not None: - scores = entropies.to(dtype=torch.float32) - elif preds is not None: - scores = entropy(preds) - else: - scores, _ = calculate_entropies( - tokens, - self.patcher, - self.config.patching_batch_size, - self.config.patching_device, - ) - patch_start_ids = find_entropy_patch_start_ids( - scores, - self.config.patch_size, - include_next_token=include_next_token, - threshold=threshold if threshold is not None else self.config.patching_threshold, - threshold_add=self.config.patching_threshold_add, - monotonicity=self.config.monotonicity, - ) - patch_lengths = patch_lengths_from_start_ids( - patch_start_ids, seq_len_next_tok - ) - else: - raise NotImplementedError(f"self.config.patching_mode {self.config.patching_mode}") + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) + decoder_patch_lengths = patch_lengths[:, 1:] + + # Create patch IDs for the decoder sequence using the remaining patches + return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) + - # Apply any processing to patch lengths - if self.config.max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [ - split_large_numbers(pl, self.config.max_patch_length) - for pl in patch_lengths.tolist() - ] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=tokens.dtype, device=tokens.device - ) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] - assert ( - torch.sum(patch_lengths) - == tokens.numel() + include_next_token * tokens.shape[0] - ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" - return patch_lengths, scores def forward( self, @@ -2014,14 +1621,52 @@ def forward( # Patching if patch_lengths is None: - assert ( - getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward - ), "Patch in forward not enabled and no patch_lengths passed." - patch_lengths, tok_scores = self.patch( - local_encoder_tokens, - include_next_token=True, - threshold=self.config.patching_threshold, - ) + # assert ( + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." + + # PATCHER MODEL DEFINED + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.config.patch_size, + include_next_token=True, + threshold=self.config.patching_threshold, + threshold_add=self.config.patching_threshold_add, + monotonicity=self.config.monotonicity, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + # Apply any processing to patch lengths + if self.config.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] else: if nb_boe > 0: patch_lengths[:, 0] += nb_boe @@ -2029,9 +1674,9 @@ def forward( assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths - patch_ids = patch_ids_from_lengths( + patch_ids = self._patch_ids_from_lengths( patch_lengths, local_encoder_tokens.shape[-1] - ).to(tokens.device) + ) assert torch.max(patch_ids) + 1 <= torch.max( (patch_lengths != 0).sum(dim=-1) ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" @@ -2090,7 +1735,7 @@ def forward( dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] # Generate decoder patch IDs - decoder_patch_ids = decoder_patch_ids_from_lengths( + decoder_patch_ids = self._decoder_patch_ids_from_lengths( patch_lengths, nb_boe, local_decoder_tokens.shape[-1] ) assert ( @@ -2143,6 +1788,349 @@ def init_weights(self): b=3 * emb_std, ) + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + # Store config reference for later use + self.config = config + + # Extract patcher parameters from BLTConfig + self.dim = config.patcher_dim + self.init_base_std = config.patcher_init_base_std + self.attn_impl = config.patcher_attn_impl + self.attn_bias_type = config.patcher_attn_bias_type + self.init_std_factor = config.patcher_init_std_factor + self.max_seqlen = config.patcher_max_seqlen + n_layers = config.patcher_n_layers + n_heads = config.patcher_n_heads + head_dim = config.patcher_head_dim + rope_theta = config.patcher_rope_theta + rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product + norm_eps = config.patcher_norm_eps + vocab_size = config.patcher_vocab_size + weight_tying = config.patcher_weight_tying + sliding_window = config.patcher_sliding_window + eos_token_id = config.patcher_eos_token_id + + self.rope_embeddings = RotaryEmbedding( + theta=rope_theta, + head_dim=head_dim or self.dim // n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = eos_token_id + + # Extract additional parameters for BLTTransformerLayer + n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) + multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) + ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) + + # Create a simple parameter dict for BLTTransformerLayer + layer_params = { + 'dim': self.dim, + 'n_heads': n_heads, + 'head_dim': head_dim, + 'n_kv_heads': n_kv_heads, + 'rope_theta': rope_theta, + 'multiple_of': multiple_of, + 'ffn_dim_multiplier': ffn_dim_multiplier, + 'norm_eps': norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(n_layers): + self.layers.append(BLTTransformerLayer(layer_params)) + + # LMTransformer specific attributes + self.weight_tying = weight_tying + self.sliding_window = sliding_window + + assert vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) + + self.norm = RMSNorm(self.dim, eps=norm_eps) + + self.output = nn.Linear( + self.dim, + vocab_size, + bias=False, + ) + + if self.weight_tying: + self.output.weight = self.tok_embeddings.weight + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + attn_impl: str | None = None, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + threshold_add: Optional[float] = None, + monotonicity: bool = False, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1 + device: Optional[str] = None, + enable_grad: bool = False, + ): + attn_impl = self.attn_impl if attn_impl is None else attn_impl + + # Handle chunked processing for entropy calculation + # grad_context = nullcontext() if enable_grad else torch.no_grad() + # with grad_context: + entropies = [] + preds = [] + max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=split, + eos_id=self.eos_id, + ) + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) + + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0) + concat_preds = concat_preds.reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold, + threshold_add=threshold_add, + monotonicity=monotonicity, + ) + patch_lengths = self.patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + # Default to byte-level patching + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device + ) + + # Apply any processing to patch lengths + if max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + self.split_large_numbers(pl, max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=token_values.dtype, device=token_values.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + + return concat_entropies, patch_lengths, concat_preds + + def reset_parameters(self, init_std=None): + self.norm.reset_parameters() + + def init_weights(self): + self.reset_parameters() + init_std = self.dim ** (-0.5) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + @staticmethod + def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. + entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + + @staticmethod + def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + def init_hash_embeddings( config, local_encoder_dim: int, @@ -2171,7 +2159,7 @@ def init_hash_embeddings( __all__ = [ "BLTPreTrainedModel", "BLTModel", - "LMTransformer", + "BLTPatcher", "LocalEncoder", "LocalDecoder", "GlobalTransformer", From 907eca15b8d9858b4e789df2511ea80404863bff Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 12 Jun 2025 21:44:11 +0000 Subject: [PATCH 010/139] add conversion script --- src/convert_blt_to_hf.py | 422 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 422 insertions(+) create mode 100644 src/convert_blt_to_hf.py diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py new file mode 100644 index 000000000000..fff4e8a971a8 --- /dev/null +++ b/src/convert_blt_to_hf.py @@ -0,0 +1,422 @@ +import argparse +import json +import logging +import os +from typing import Dict, Any, Optional + +import torch +from huggingface_hub import hf_hub_download, snapshot_download, upload_folder +from safetensors.torch import load_file, save_file + +from transformers.utils import logging as transformers_logging + +logger = transformers_logging.get_logger(__name__) +transformers_logging.set_verbosity_info() + +from transformers.models.blt_wip.modeling_blt_wip import BLTModel +from transformers.models.blt_wip.configuration_blt import BLTConfig + + +def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: + config_path = hf_hub_download( + repo_id=model_id, + filename="config.json", + cache_dir=cache_dir + ) + + weights_path = hf_hub_download( + repo_id=model_id, + filename="model.safetensors", + cache_dir=cache_dir + ) + + entropy_params_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/params.json", + cache_dir=cache_dir + ) + + entropy_weights_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/consolidated.pth", + cache_dir=cache_dir + ) + + return { + "config": config_path, + "weights": weights_path, + "entropy_params": entropy_params_path, + "entropy_weights": entropy_weights_path + } + + +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: + + logger.info("Merging confi") + + # Load BLT configuration + with open(config_path, 'r') as f: + main_config = json.load(f) + + # Load Patcher entropy model parameters + with open(entropy_params_path, 'r') as f: + entropy_data = json.load(f) + + entropy_model_params = entropy_data.get("entropy_model", {}) + patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) + + # Create unified configuration + unified_config = main_config.copy()['args'] + + # Ensure other integer parameters are properly typed + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: + if key in unified_config and not isinstance(unified_config[key], int): + unified_config[key] = int(unified_config[key]) + + patch_size = patcher_args.get("patch_size", 8) + if isinstance(patch_size, float): + patch_size = int(patch_size) + + unified_config.update({ + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + + "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "patcher_dim": int(entropy_model_params.get("dim", 512)), + "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), + "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), + "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None, + "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, + "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "patcher_dropout": entropy_model_params.get("dropout", 0.0), + "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None, + "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), + "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), + "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False), + "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "patcher_init_base_std": entropy_model_params.get("init_base_std"), + "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), + "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), + "patcher_weight_tying": entropy_model_params.get("weight_tying", False), + "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), + "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), + }) + + logger.info(f"Merged configuration with {len(unified_config)} parameters") + return unified_config + + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: + logger.info("Merging model weights") + + main_weights = load_file(weights_path) + logger.info(f"Loaded main model weights: {len(main_weights)} tensors") + + entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True) + + if 'model' in entropy_weights: + entropy_weights = entropy_weights['model'] + elif 'state_dict' in entropy_weights: + entropy_weights = entropy_weights['state_dict'] + + logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") + + # unified state dict + unified_weights = main_weights.copy() + + # Add entropy model weights with "patcher." prefix + for key, tensor in entropy_weights.items(): + patcher_key = f"patcher.{key}" + unified_weights[patcher_key] = tensor + + logger.info(f"Merged weights: {len(unified_weights)} tensors total") + return unified_weights + + +def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): + logger.info("Creating tokenizer config") + + tokenizer_config = { + "tokenizer_class": "BltTokenizer", + "vocab_size": config.get("vocab_size", 256), + "model_max_length": config.get("max_seqlen", 1024), + "add_bos_token": True, + "add_eos_token": True, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") + with open(tokenizer_path, 'w') as f: + json.dump(tokenizer_config, f, indent=2) + + logger.info(f"Tokenizer config saved to {tokenizer_path}") + + +def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): + logger.info("Validating unified model") + + required_keys = [ + "vocab_size", "dim", "n_layers", "n_heads", + "patch_in_forward", "patcher_vocab_size", "patcher_dim" + ] + + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + logger.warning(f"Missing configuration keys: {missing_keys}") + + # Check for patcher weights + patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] + if not patcher_weights: + logger.warning("No patcher weights found in unified weights") + else: + logger.info(f"Found {len(patcher_weights)} patcher weight tensors") + + main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] + logger.info(f"Found {len(main_weights)} main model weight tensors") + + try: + logger.info("Testing model instantiation...") + blt_config = BLTConfig(**config) + model = BLTModel(blt_config) + + logger.info("Testing weight loading...") + try: + missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) + if missing_keys: + logger.warning(f"Missing keys during weight loading: {missing_keys}") + if unexpected_keys: + logger.warning(f"Unexpected keys during weight loading: {unexpected_keys}") + logger.info("Weight loading successful") + except Exception as weight_error: + logger.warning(f"Weight loading failed: {weight_error}") + logger.info("Model instantiation successful, but weight loading had issues") + + except Exception as e: + logger.error(f"Model validation failed: {e}") + + logger.info("Model validation completed") + + +def push_to_hub( + local_dir: str, + repo_id: str, + commit_message: str = "Upload converted BLT model", + private: bool = False, + token: Optional[str] = None, +) -> None: + """ + Push the converted model to Hugging Face Hub. + + Args: + local_dir: Local directory containing the converted model files + repo_id: Repository ID on Hugging Face Hub (e.g., "username/model-name") + commit_message: Commit message for the upload + private: Whether to create a private repository + token: Hugging Face authentication token (if not provided, will use cached token) + """ + logger.info(f"Pushing converted model to Hub: {repo_id}") + + try: + # Upload the entire directory to the Hub + upload_folder( + folder_path=local_dir, + repo_id=repo_id, + commit_message=commit_message, + repo_type="model", + token=token, + ) + logger.info(f"Successfully pushed model to {repo_id}") + + except Exception as e: + logger.error(f"Failed to push model to Hub: {e}") + raise + + +def convert_hf_blt_to_unified( + model_id: str, + output_dir: str, + config_name: str = "config.json", + weights_name: str = "model.bin", + safe_serialization: bool = True, + cache_dir: Optional[str] = None, + push_to_hub_repo: Optional[str] = None, + hub_private: bool = False, + hub_token: Optional[str] = None, +) -> None: + """ + Convert BLT model from HuggingFace Hub format to unified format. + + Args: + model_id: HuggingFace model ID (e.g., "facebook/blt-1b") + output_dir: Output directory for unified model + config_name: Name for unified config file + weights_name: Name for unified weights file + safe_serialization: Whether to use safetensors format + cache_dir: Cache directory for downloads + push_to_hub_repo: Repository ID to push the converted model to (optional) + hub_private: Whether to create a private repository on the Hub + hub_token: Hugging Face authentication token + """ + logger.info(f"Converting {model_id} to unified transformers format") + + file_paths = download_model_files(model_id, cache_dir) + + # Merge configurations + unified_config = merge_configurations( + file_paths["config"], + file_paths["entropy_params"] + ) + + # Merge weights + unified_weights = merge_weights( + file_paths["weights"], + file_paths["entropy_weights"] + ) + + validate_unified_model(unified_config, unified_weights) + + os.makedirs(output_dir, exist_ok=True) + + config_path = os.path.join(output_dir, config_name) + with open(config_path, 'w') as f: + json.dump(unified_config, f, indent=2) + + if safe_serialization and weights_name.endswith('.bin'): + weights_name = weights_name.replace('.bin', '.safetensors') + elif not safe_serialization and weights_name.endswith('.safetensors'): + weights_name = weights_name.replace('.safetensors', '.bin') + + weights_path = os.path.join(output_dir, weights_name) + if safe_serialization: + save_file(unified_weights, weights_path) + else: + torch.save(unified_weights, weights_path) + + logger.info(f"Unified config and weights saved to {weights_path}") + + # Create tokenizer config + create_tokenizer_config(output_dir, unified_config) + + logger.info(f"Conversion completed, model saved to: {output_dir}") + + # Push to Hub if requested + if push_to_hub_repo: + push_to_hub( + local_dir=output_dir, + repo_id=push_to_hub_repo, + commit_message=f"Upload unified BLT model converted from {model_id}", + private=hub_private, + token=hub_token, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert BLT models from HuggingFace Hub format to unified format", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--model_id", + type=str, + default="facebook/blt-1b", + help="HuggingFace model ID (e.g., facebook/blt-1b)" + ) + parser.add_argument( + "--output_dir", + type=str, + default="./new_unified_blt_debug", + help="Output directory for unified model" + ) + + # Optional + parser.add_argument( + "--config_name", + type=str, + default="config.json", + help="Name for unified config file (default: config.json)" + ) + parser.add_argument( + "--weights_name", + type=str, + default="model.bin", + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + default=True, + ) + parser.add_argument( + "--no_safe_serialization", + dest="safe_serialization", + action="store_false", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + ) + parser.add_argument( + "--debug", + action="store_true", + default=True, # Enable debug by default for easier debugging + ) + + # Hub upload arguments + parser.add_argument( + "--push_to_hub", + type=str, + default="itazap/blt-1b", + ) + parser.add_argument( + "--hub_private", + action="store_true", + default=False, + help="Whether to create a private repository on the Hub" + ) + parser.add_argument( + "--hub_token", + type=str, + default="hf_your_token_here", + help="Hugging Face authentication token (if not provided, will use cached token)" + ) + + args = parser.parse_args() + + transformers_logging.set_verbosity_debug() + logging.basicConfig(level=logging.DEBUG) + + try: + convert_hf_blt_to_unified( + model_id=args.model_id, + output_dir=args.output_dir, + config_name=args.config_name, + weights_name=args.weights_name, + safe_serialization=args.safe_serialization, + cache_dir=args.cache_dir, + push_to_hub_repo=args.push_to_hub, + hub_private=args.hub_private, + hub_token=args.hub_token, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file From c4b1775309f5ac6a5155d65de7ffd58837f11b59 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 13 Jun 2025 10:42:15 +0000 Subject: [PATCH 011/139] load from new checkpoing with form_pretrained --- src/convert_blt_to_hf.py | 39 +---- src/demo_hf.py | 117 +++++--------- .../models/blt_wip/modeling_blt_wip.py | 151 ++++++++---------- 3 files changed, 105 insertions(+), 202 deletions(-) diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index fff4e8a971a8..4a84822bf44b 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -68,7 +68,6 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str # Create unified configuration unified_config = main_config.copy()['args'] - # Ensure other integer parameters are properly typed for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: if key in unified_config and not isinstance(unified_config[key], int): unified_config[key] = int(unified_config[key]) @@ -204,7 +203,6 @@ def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tens logger.info("Weight loading successful") except Exception as weight_error: logger.warning(f"Weight loading failed: {weight_error}") - logger.info("Model instantiation successful, but weight loading had issues") except Exception as e: logger.error(f"Model validation failed: {e}") @@ -219,17 +217,6 @@ def push_to_hub( private: bool = False, token: Optional[str] = None, ) -> None: - """ - Push the converted model to Hugging Face Hub. - - Args: - local_dir: Local directory containing the converted model files - repo_id: Repository ID on Hugging Face Hub (e.g., "username/model-name") - commit_message: Commit message for the upload - private: Whether to create a private repository - token: Hugging Face authentication token (if not provided, will use cached token) - """ - logger.info(f"Pushing converted model to Hub: {repo_id}") try: # Upload the entire directory to the Hub @@ -258,20 +245,6 @@ def convert_hf_blt_to_unified( hub_private: bool = False, hub_token: Optional[str] = None, ) -> None: - """ - Convert BLT model from HuggingFace Hub format to unified format. - - Args: - model_id: HuggingFace model ID (e.g., "facebook/blt-1b") - output_dir: Output directory for unified model - config_name: Name for unified config file - weights_name: Name for unified weights file - safe_serialization: Whether to use safetensors format - cache_dir: Cache directory for downloads - push_to_hub_repo: Repository ID to push the converted model to (optional) - hub_private: Whether to create a private repository on the Hub - hub_token: Hugging Face authentication token - """ logger.info(f"Converting {model_id} to unified transformers format") file_paths = download_model_files(model_id, cache_dir) @@ -309,17 +282,15 @@ def convert_hf_blt_to_unified( logger.info(f"Unified config and weights saved to {weights_path}") - # Create tokenizer config create_tokenizer_config(output_dir, unified_config) logger.info(f"Conversion completed, model saved to: {output_dir}") - # Push to Hub if requested if push_to_hub_repo: push_to_hub( local_dir=output_dir, repo_id=push_to_hub_repo, - commit_message=f"Upload unified BLT model converted from {model_id}", + commit_message=f"Upload BLT model converted", private=hub_private, token=hub_token, ) @@ -335,13 +306,11 @@ def main(): "--model_id", type=str, default="facebook/blt-1b", - help="HuggingFace model ID (e.g., facebook/blt-1b)" ) parser.add_argument( "--output_dir", type=str, default="./new_unified_blt_debug", - help="Output directory for unified model" ) # Optional @@ -349,7 +318,6 @@ def main(): "--config_name", type=str, default="config.json", - help="Name for unified config file (default: config.json)" ) parser.add_argument( "--weights_name", @@ -374,10 +342,9 @@ def main(): parser.add_argument( "--debug", action="store_true", - default=True, # Enable debug by default for easier debugging + default=True, ) - # Hub upload arguments parser.add_argument( "--push_to_hub", type=str, @@ -387,13 +354,11 @@ def main(): "--hub_private", action="store_true", default=False, - help="Whether to create a private repository on the Hub" ) parser.add_argument( "--hub_token", type=str, default="hf_your_token_here", - help="Hugging Face authentication token (if not provided, will use cached token)" ) args = parser.parse_args() diff --git a/src/demo_hf.py b/src/demo_hf.py index f8db4eaf117d..92bc39135753 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -1,23 +1,21 @@ -from huggingface_hub import hf_hub_download import json - import logging import os import torch +from huggingface_hub import hf_hub_download -from transformers.models.blt_wip.modeling_blt_wip import BLTModel from transformers.models.blt_wip.configuration_blt import BLTConfig +from transformers.models.blt_wip.modeling_blt_wip import BLTModel from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer + logger = logging.getLogger() -import os os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" -def get_generation_range( - prompt_tokens: list[list[int]] | None, max_gen_len: int -) -> tuple[int, int]: + +def get_generation_range(prompt_tokens: list[list[int]] | None, max_gen_len: int) -> tuple[int, int]: batch_min_prompt_length = min([len(t) for t in prompt_tokens]) batch_max_prompt_length = max([len(t) for t in prompt_tokens]) return batch_min_prompt_length, batch_max_prompt_length + max_gen_len @@ -63,10 +61,7 @@ def generate( model.eval() prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts] # Truncation - prompt_tokens = [ - t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :] - for t in prompt_tokens - ] + prompt_tokens = [t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :] for t in prompt_tokens] start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len) batch_size = len(prompt_tokens) tokens = torch.full((batch_size, end_pos), tokenizer.pad_id, dtype=torch.long, device=device) @@ -91,96 +86,56 @@ def generate( else: next_token = torch.argmax(logits, dim=-1) - next_token = torch.where( - input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token - ) + next_token = torch.where(input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token) tokens[:, curr_pos] = next_token if remove_prompts: generated_tokens = [ - t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist() - for i, t in enumerate(tokens) + t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens) ] else: - generated_tokens = [ - t[: len(prompt_tokens[i]) + max_gen_len].tolist() - for i, t in enumerate(tokens) - ] + generated_tokens = [t[: len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens)] return generated_tokens - def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - #HF - blt_repo = "facebook/blt-1b" - - # Get the model's default configuration and entropy model params - print("Loading model configuration...") - config_path = hf_hub_download(repo_id=blt_repo, filename="config.json") - entropy_params_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/params.json") - entropy_checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/consolidated.pth") - entropy_dir = os.path.dirname(entropy_checkpoint_path) - - with open(config_path, 'r') as f: - config = json.load(f) - with open(entropy_params_path, 'r') as f: - entropy_params = json.load(f) - - config['args']['attn_bias_type'] = 'causal' - config['args']['attn_impl'] = 'sdpa' - - # Create model using normal constructor instead of from_pretrained - from transformers.models.blt_wip.configuration_blt import BLTConfig - model_config = BLTConfig(**config['args']) - - # Set entropy model parameters manually - patcher_args = entropy_params["data"]["patcher_args"] - model_config.patch_in_forward = True - model_config.realtime_patching = True - model_config.patch_size = patcher_args["patch_size"] - model_config.patching_mode = "entropy" - model_config.patching_threshold = patcher_args["threshold"] - model_config.patching_threshold_add = patcher_args["threshold_add"] - model_config.max_patch_length = patcher_args["max_patch_length"] - model_config.patching_batch_size = patcher_args["patching_batch_size"] - model_config.patching_device = patcher_args["patching_device"] - model_config.monotonicity = patcher_args["monotonicity"] - model_config.entropy_model_checkpoint_dir = entropy_dir - - # Use direct construction instead of from_pretrained to avoid meta tensor issues - print("Creating model...") - model = BLTModel(model_config).to(device) - - # Load model weights manually - print("Loading model weights...") - from safetensors.torch import load_file - checkpoint_path = hf_hub_download(repo_id=blt_repo, filename="model.safetensors") - state_dict = load_file(checkpoint_path) - model.load_state_dict(state_dict, strict=False) - - tokenizer = BltTokenizer( - vocab_size_unit_1=model_config.vocab_size, - add_bos=True, - add_eos=True - ) + # HF + blt_repo = "itazap/blt-1b" + + # Load model using from_pretrained + print("Loading model...") + model = BLTModel.from_pretrained(blt_repo).to(device) + + tokenizer = BltTokenizer(vocab_size_unit_1=model.config.vocab_size, add_bos=True, add_eos=True) prompts = [prompt] - outputs = generate( - prompts, - model=model, - tokenizer=tokenizer, - max_gen_len=200, - device=device - ) + + import torch + + def snapshot_state_dict(state_dict, n=3, filename='snapshot.txt'): + with open(filename, 'w') as f: + keys = list(state_dict.keys()) + # selected_keys = keys[:n] + keys[-n:] if len(keys) > 2 * n else keys + for key in keys: + f.write(f"{key}:\n") + f.write(f"{state_dict[key].flatten()[:5]}\n") # Print first 5 values of each tensor + f.write("\n") + print(f"Snapshot saved to {filename}") + + snapshot_state_dict(model.state_dict(), n=5, filename='demo_hf.txt') + + + outputs = generate(prompts, model=model, tokenizer=tokenizer, max_gen_len=200, device=device) + text_outputs = [tokenizer.decode(t) for t in outputs] for p, t in zip(prompts, text_outputs): print(f'Prompt: "{p}"') print(f'Completion: "{t}"') print() + if __name__ == "__main__": main() - diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index b8eedb21498d..71e390077819 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -1481,60 +1481,33 @@ def __init__(self, config: BLTConfig): # Initialize patcher if needed if config.patch_in_forward: - if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: - # Load entropy model directly - entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir - - if not os.path.exists(entropy_model_checkpoint_dir): - raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}") - - # Load entropy model parameters - params_path = os.path.join(entropy_model_checkpoint_dir, "params.json") - if not os.path.exists(params_path): - raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}") - - with open(params_path) as fr: - reloaded = json.loads(fr.read()) + # Create patcher with config + self.patcher = BLTPatcher(config) + # Set patcher to eval mode and disable gradients + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None - torch.set_default_dtype(torch.bfloat16) - model_params = reloaded["entropy_model"] - logger.warning( - "Update checkpoint to load attn and sliding window args from checkpoint" - ) - - # Override patcher configuration with actual entropy model parameters from checkpoint - config.patcher_dim = model_params["dim"] - config.patcher_n_layers = model_params["n_layers"] - config.patcher_n_heads = model_params["n_heads"] - config.patcher_max_seqlen = model_params["max_seqlen"] - config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"] - config.patcher_vocab_size = model_params["vocab_size"] - # Use sensible defaults for parameters not in checkpoint - config.patcher_attn_bias_type = "local_block_causal" - config.patcher_attn_impl = "sdpa" # originally xformers - config.patcher_sliding_window = 512 - - # BLTPatcher will extract patcher_ parameters from config directly - self.patcher = BLTPatcher(config) - - state_path = os.path.join( - entropy_model_checkpoint_dir, "consolidated.pth" - ) + def init_weights(self): + self.local_encoder.init_weights() + self.global_transformer.init_weights() + self.local_decoder.init_weights() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.patcher.load_state_dict( - torch.load(state_path, map_location=device)["model"], strict=False + if self.encoder_hash_tok_embedding is not None: + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, ) - self.patcher.to(device) - self.patcher = self.patcher.eval() - # no grads for the model: - for param in self.patcher.parameters(): - param.requires_grad = False - else: - self.patcher = None - - # Initialize weights and apply final processing - self.post_init() + + if self.patcher is not None: + self.patcher.init_weights() def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """ @@ -1772,22 +1745,6 @@ def forward( ) return output - def init_weights(self): - self.local_encoder.init_weights() - self.global_transformer.init_weights() - self.local_decoder.init_weights() - - if self.encoder_hash_tok_embedding is not None: - emb_std = self.local_encoder.dim ** (-0.5) - for emb in self.encoder_hash_tok_embedding: - nn.init.trunc_normal_( - emb.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) - class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): @@ -1876,15 +1833,13 @@ def forward( threshold_add: Optional[float] = None, monotonicity: bool = False, max_patch_length: Optional[int] = None, - patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1 + patching_batch_size: int = 1, device: Optional[str] = None, enable_grad: bool = False, ): attn_impl = self.attn_impl if attn_impl is None else attn_impl # Handle chunked processing for entropy calculation - # grad_context = nullcontext() if enable_grad else torch.no_grad() - # with grad_context: entropies = [] preds = [] max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) @@ -1929,7 +1884,7 @@ def forward( concat_entropies = concat_entropies.reshape(token_values.shape) concat_preds = torch.cat(preds, dim=0) concat_preds = concat_preds.reshape(token_values.shape[0], -1) - + # Always compute patch lengths from concatenated entropies bs, seq_len = token_values.shape seq_len_next_tok = seq_len + 1 if include_next_token else seq_len @@ -1977,34 +1932,62 @@ def forward( return concat_entropies, patch_lengths, concat_preds - def reset_parameters(self, init_std=None): - self.norm.reset_parameters() - def init_weights(self): - self.reset_parameters() - init_std = self.dim ** (-0.5) + """Initialize weights for the patcher model""" + # Initialize RoPE embeddings + self.rope_embeddings.reset_parameters() + + # Initialize norm layer + self.norm.reset_parameters() + + # Initialize token embeddings + emb_std = self.dim ** (-0.5) nn.init.trunc_normal_( self.tok_embeddings.weight, mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, ) - self.rope_embeddings.reset_parameters() + # Initialize transformer layers for depth, layer in enumerate(self.layers): factor = self.config.get_init_std_factor(depth) layer.init_weights(self.init_base_std, factor) - + + # Initialize output layer if not weight tied if not self.weight_tying: nn.init.trunc_normal_( self.output.weight, mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, ) + def _init_weights(self, module): + """Initialize weights for a specific module""" + if isinstance(module, nn.Linear): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=self.init_base_std or (self.dim ** (-0.5)), + a=-3 * (self.init_base_std or (self.dim ** (-0.5))), + b=3 * (self.init_base_std or (self.dim ** (-0.5))), + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=self.init_base_std or (self.dim ** (-0.5)), + a=-3 * (self.init_base_std or (self.dim ** (-0.5))), + b=3 * (self.init_base_std or (self.dim ** (-0.5))), + ) + + + @staticmethod def entropy(scores): """ From fececd1c733fad2b0dce1f425333d5685069e123 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 13 Jun 2025 11:15:00 +0000 Subject: [PATCH 012/139] fixed demo from_pretrained --- src/demo_hf.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 92bc39135753..68fc1565f719 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -101,33 +101,13 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - # HF blt_repo = "itazap/blt-1b" - # Load model using from_pretrained - print("Loading model...") model = BLTModel.from_pretrained(blt_repo).to(device) - tokenizer = BltTokenizer(vocab_size_unit_1=model.config.vocab_size, add_bos=True, add_eos=True) prompts = [prompt] - import torch - - - def snapshot_state_dict(state_dict, n=3, filename='snapshot.txt'): - with open(filename, 'w') as f: - keys = list(state_dict.keys()) - # selected_keys = keys[:n] + keys[-n:] if len(keys) > 2 * n else keys - for key in keys: - f.write(f"{key}:\n") - f.write(f"{state_dict[key].flatten()[:5]}\n") # Print first 5 values of each tensor - f.write("\n") - print(f"Snapshot saved to {filename}") - - snapshot_state_dict(model.state_dict(), n=5, filename='demo_hf.txt') - - outputs = generate(prompts, model=model, tokenizer=tokenizer, max_gen_len=200, device=device) text_outputs = [tokenizer.decode(t) for t in outputs] From f2604f3016cfd97503825f917b3f7acd02ff934f Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 13 Jun 2025 14:44:12 +0000 Subject: [PATCH 013/139] clean up --- src/convert_blt_to_hf.py | 263 +- src/demo_hf.py | 3 - src/transformers/models/blt_wip/blt_args.py | 68 +- .../models/blt_wip/configuration_blt.py | 97 +- .../convert_hf_blt_original_to_unified.py | 540 ++++ .../models/blt_wip/modeling_blt_wip.py | 519 ++-- .../models/blt_wip/modeling_blt_wip_backup.py | 2166 +++++++++++++++++ .../blt_wip/tokenizers/abstract_tokenizer.py | 4 +- .../blt_wip/tokenizers/blt_tokenizer.py | 20 +- .../tokenizers/sentence_piece_tokenizer.py | 18 +- .../blt_wip/unified_blt_debug/config.json | 144 ++ 11 files changed, 3256 insertions(+), 586 deletions(-) create mode 100644 src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py create mode 100644 src/transformers/models/blt_wip/modeling_blt_wip_backup.py create mode 100644 src/transformers/models/blt_wip/unified_blt_debug/config.json diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index 4a84822bf44b..cb933961b4e2 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -2,151 +2,144 @@ import json import logging import os -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional import torch -from huggingface_hub import hf_hub_download, snapshot_download, upload_folder +from huggingface_hub import hf_hub_download, upload_folder from safetensors.torch import load_file, save_file +from transformers.models.blt_wip.configuration_blt import BLTConfig +from transformers.models.blt_wip.modeling_blt_wip import BLTModel from transformers.utils import logging as transformers_logging + logger = transformers_logging.get_logger(__name__) transformers_logging.set_verbosity_info() -from transformers.models.blt_wip.modeling_blt_wip import BLTModel -from transformers.models.blt_wip.configuration_blt import BLTConfig - def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: - config_path = hf_hub_download( - repo_id=model_id, - filename="config.json", - cache_dir=cache_dir - ) - - weights_path = hf_hub_download( - repo_id=model_id, - filename="model.safetensors", - cache_dir=cache_dir - ) - - entropy_params_path = hf_hub_download( - repo_id=model_id, - filename="entropy_model/params.json", - cache_dir=cache_dir - ) - + config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) + + weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) + + entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) + entropy_weights_path = hf_hub_download( - repo_id=model_id, - filename="entropy_model/consolidated.pth", - cache_dir=cache_dir + repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir ) - + return { "config": config_path, "weights": weights_path, "entropy_params": entropy_params_path, - "entropy_weights": entropy_weights_path + "entropy_weights": entropy_weights_path, } - -def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: logger.info("Merging confi") - + # Load BLT configuration - with open(config_path, 'r') as f: + with open(config_path, "r") as f: main_config = json.load(f) - + # Load Patcher entropy model parameters - with open(entropy_params_path, 'r') as f: + with open(entropy_params_path, "r") as f: entropy_data = json.load(f) - + entropy_model_params = entropy_data.get("entropy_model", {}) patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) - + # Create unified configuration - unified_config = main_config.copy()['args'] - + unified_config = main_config.copy()["args"] + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: if key in unified_config and not isinstance(unified_config[key], int): unified_config[key] = int(unified_config[key]) - + patch_size = patcher_args.get("patch_size", 8) if isinstance(patch_size, float): patch_size = int(patch_size) - - unified_config.update({ - "patch_in_forward": True, - "realtime_patching": True, - "patching_mode": "entropy", - - "patch_size": patch_size, - "patching_threshold": patcher_args.get("threshold", 0.5), - "patching_threshold_add": patcher_args.get("threshold_add", 0.0), - "max_patch_length": patcher_args.get("max_patch_length"), - "patching_batch_size": patcher_args.get("patching_batch_size", 1), - "patching_device": patcher_args.get("patching_device", "cuda"), - "monotonicity": patcher_args.get("monotonicity", False), - - "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), - "patcher_dim": int(entropy_model_params.get("dim", 512)), - "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), - "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), - "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None, - "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, - "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), - "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), - "patcher_dropout": entropy_model_params.get("dropout", 0.0), - "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None, - "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), - "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), - "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), - "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False), - "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), - "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), - "patcher_init_base_std": entropy_model_params.get("init_base_std"), - "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), - "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), - "patcher_weight_tying": entropy_model_params.get("weight_tying", False), - "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), - "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), - }) - + + unified_config.update( + { + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "patcher_dim": int(entropy_model_params.get("dim", 512)), + "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), + "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), + "patcher_head_dim": int(entropy_model_params.get("head_dim")) + if entropy_model_params.get("head_dim") is not None + else None, + "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) + if entropy_model_params.get("n_kv_heads") is not None + else None, + "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "patcher_dropout": entropy_model_params.get("dropout", 0.0), + "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) + if entropy_model_params.get("sliding_window") is not None + else None, + "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), + "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), + "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get( + "rope_use_fp32_in_outer_product", False + ), + "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "patcher_init_base_std": entropy_model_params.get("init_base_std"), + "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), + "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), + "patcher_weight_tying": entropy_model_params.get("weight_tying", False), + "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), + "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), + } + ) + logger.info(f"Merged configuration with {len(unified_config)} parameters") return unified_config def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: logger.info("Merging model weights") - + main_weights = load_file(weights_path) logger.info(f"Loaded main model weights: {len(main_weights)} tensors") - - entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True) - - if 'model' in entropy_weights: - entropy_weights = entropy_weights['model'] - elif 'state_dict' in entropy_weights: - entropy_weights = entropy_weights['state_dict'] - + + entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) + + if "model" in entropy_weights: + entropy_weights = entropy_weights["model"] + elif "state_dict" in entropy_weights: + entropy_weights = entropy_weights["state_dict"] + logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") - + # unified state dict unified_weights = main_weights.copy() - + # Add entropy model weights with "patcher." prefix for key, tensor in entropy_weights.items(): patcher_key = f"patcher.{key}" unified_weights[patcher_key] = tensor - + logger.info(f"Merged weights: {len(unified_weights)} tensors total") return unified_weights def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): logger.info("Creating tokenizer config") - + tokenizer_config = { "tokenizer_class": "BltTokenizer", "vocab_size": config.get("vocab_size", 256), @@ -158,36 +151,41 @@ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): "pad_token": "", "unk_token": "", } - + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") - with open(tokenizer_path, 'w') as f: + with open(tokenizer_path, "w") as f: json.dump(tokenizer_config, f, indent=2) - + logger.info(f"Tokenizer config saved to {tokenizer_path}") def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): logger.info("Validating unified model") - + required_keys = [ - "vocab_size", "dim", "n_layers", "n_heads", - "patch_in_forward", "patcher_vocab_size", "patcher_dim" + "vocab_size", + "dim", + "n_layers", + "n_heads", + "patch_in_forward", + "patcher_vocab_size", + "patcher_dim", ] - + missing_keys = [key for key in required_keys if key not in config] if missing_keys: logger.warning(f"Missing configuration keys: {missing_keys}") - + # Check for patcher weights patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] if not patcher_weights: logger.warning("No patcher weights found in unified weights") else: logger.info(f"Found {len(patcher_weights)} patcher weight tensors") - + main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] logger.info(f"Found {len(main_weights)} main model weight tensors") - + try: logger.info("Testing model instantiation...") blt_config = BLTConfig(**config) @@ -203,10 +201,10 @@ def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tens logger.info("Weight loading successful") except Exception as weight_error: logger.warning(f"Weight loading failed: {weight_error}") - + except Exception as e: logger.error(f"Model validation failed: {e}") - + logger.info("Model validation completed") @@ -217,7 +215,6 @@ def push_to_hub( private: bool = False, token: Optional[str] = None, ) -> None: - try: # Upload the entire directory to the Hub upload_folder( @@ -228,7 +225,7 @@ def push_to_hub( token=token, ) logger.info(f"Successfully pushed model to {repo_id}") - + except Exception as e: logger.error(f"Failed to push model to Hub: {e}") raise @@ -246,34 +243,28 @@ def convert_hf_blt_to_unified( hub_token: Optional[str] = None, ) -> None: logger.info(f"Converting {model_id} to unified transformers format") - + file_paths = download_model_files(model_id, cache_dir) - + # Merge configurations - unified_config = merge_configurations( - file_paths["config"], - file_paths["entropy_params"] - ) - + unified_config = merge_configurations(file_paths["config"], file_paths["entropy_params"]) + # Merge weights - unified_weights = merge_weights( - file_paths["weights"], - file_paths["entropy_weights"] - ) - + unified_weights = merge_weights(file_paths["weights"], file_paths["entropy_weights"]) + validate_unified_model(unified_config, unified_weights) - + os.makedirs(output_dir, exist_ok=True) - + config_path = os.path.join(output_dir, config_name) - with open(config_path, 'w') as f: + with open(config_path, "w") as f: json.dump(unified_config, f, indent=2) - if safe_serialization and weights_name.endswith('.bin'): - weights_name = weights_name.replace('.bin', '.safetensors') - elif not safe_serialization and weights_name.endswith('.safetensors'): - weights_name = weights_name.replace('.safetensors', '.bin') - + if safe_serialization and weights_name.endswith(".bin"): + weights_name = weights_name.replace(".bin", ".safetensors") + elif not safe_serialization and weights_name.endswith(".safetensors"): + weights_name = weights_name.replace(".safetensors", ".bin") + weights_path = os.path.join(output_dir, weights_name) if safe_serialization: save_file(unified_weights, weights_path) @@ -281,16 +272,16 @@ def convert_hf_blt_to_unified( torch.save(unified_weights, weights_path) logger.info(f"Unified config and weights saved to {weights_path}") - + create_tokenizer_config(output_dir, unified_config) logger.info(f"Conversion completed, model saved to: {output_dir}") - + if push_to_hub_repo: push_to_hub( local_dir=output_dir, repo_id=push_to_hub_repo, - commit_message=f"Upload BLT model converted", + commit_message="Upload BLT model converted", private=hub_private, token=hub_token, ) @@ -299,9 +290,9 @@ def convert_hf_blt_to_unified( def main(): parser = argparse.ArgumentParser( description="Convert BLT models from HuggingFace Hub format to unified format", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - + parser.add_argument( "--model_id", type=str, @@ -312,7 +303,7 @@ def main(): type=str, default="./new_unified_blt_debug", ) - + # Optional parser.add_argument( "--config_name", @@ -342,9 +333,9 @@ def main(): parser.add_argument( "--debug", action="store_true", - default=True, + default=True, ) - + parser.add_argument( "--push_to_hub", type=str, @@ -360,12 +351,12 @@ def main(): type=str, default="hf_your_token_here", ) - + args = parser.parse_args() - + transformers_logging.set_verbosity_debug() logging.basicConfig(level=logging.DEBUG) - + try: convert_hf_blt_to_unified( model_id=args.model_id, @@ -384,4 +375,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/demo_hf.py b/src/demo_hf.py index 68fc1565f719..d88be6783480 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -1,11 +1,8 @@ -import json import logging import os import torch -from huggingface_hub import hf_hub_download -from transformers.models.blt_wip.configuration_blt import BLTConfig from transformers.models.blt_wip.modeling_blt_wip import BLTModel from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py index 78fc3482c362..e043d1dd20a8 100644 --- a/src/transformers/models/blt_wip/blt_args.py +++ b/src/transformers/models/blt_wip/blt_args.py @@ -1,8 +1,10 @@ from enum import Enum -from typing import Any, Optional +from typing import Any + from pydantic import BaseModel, ConfigDict, model_validator from typing_extensions import Self + EOS_ID: int = 2 @@ -24,64 +26,66 @@ class PatchingModeEnum(str, Enum): class LMTransformerArgs(BaseModel): """Arguments for the Language Model Transformer (used as entropy model for patching)""" + model_config = ConfigDict() - + # Basic architecture dim: int = 512 n_layers: int = 8 head_dim: int | None = None n_heads: int | None = None n_kv_heads: int | None = None - + # Transformer configuration max_seqlen: int = 1024 norm_eps: float = 1e-5 dropout: float = 0 vocab_size: int = -1 sliding_window: int | None = None - + # Feedforward ffn_dim_multiplier: float | None = None multiple_of: int = 256 - + # Positional encoding rope_theta: float = 10000.0 rope_use_fp32_in_outer_product: bool = False - + # Attention attn_impl: str = "sdpa" attn_bias_type: str = "causal" - + # Initialization init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED - + # Embedding dimensions dim_token_emb: int | None = None - + # Model behavior weight_tying: bool = False seed: int = 42 - + # Special token config eos_id: int = EOS_ID class ByteLatentTransformerArgs(BaseModel): """Arguments for the Byte Latent Transformer (main BLT model)""" + model_config = ConfigDict() - + # Basic model configuration seed: int = 42 vocab_size: int = -1 - + # Main architecture dimensions (these will be used for creating transformer args) dim: int = 512 n_layers: int = 8 head_dim: int | None = None n_heads: int | None = None n_kv_heads: int | None = None - + # Component-specific dimensions dim_global: int = 512 dim_local_decoder: int = 512 @@ -93,31 +97,31 @@ class ByteLatentTransformerArgs(BaseModel): n_heads_local_decoder: int = 8 n_heads_local_encoder: int = 8 n_kv_heads_global: int | None = None - + # Transformer configuration (needed by transformer components) max_seqlen: int = 1024 norm_eps: float = 1e-5 dropout: float = 0 - + # Feedforward (needed by transformer components) ffn_dim_multiplier: float = 1.0 multiple_of: int = 256 - + # Positional encoding (needed by transformer components) rope_theta: float = 10000.0 rope_use_fp32_in_outer_product: bool = False - + # Attention (needed by transformer components) attn_impl: str = "sdpa" attn_bias_type: str = "causal" - + # Initialization (needed by transformer components) init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED - + # Embedding dimensions (needed by transformer components) dim_token_emb: int | None = None - + # Patching configuration patch_in_forward: bool = False realtime_patching: bool = True @@ -130,7 +134,7 @@ class ByteLatentTransformerArgs(BaseModel): patching_device: str = "cuda" max_patch_length: int | None = None entropy_model_checkpoint_dir: str | None = None - + # Cross attention configurations cross_attn_encoder: bool = False cross_attn_decoder: bool = False @@ -142,7 +146,7 @@ class ByteLatentTransformerArgs(BaseModel): cross_attn_all_layers_encoder: bool = False cross_attn_use_flex_attention: bool = True cross_attn_init_by_pooling: bool = False - + # Encoder configurations use_local_encoder_transformer: bool = False max_encoder_seq_length: int | None = None @@ -152,38 +156,32 @@ class ByteLatentTransformerArgs(BaseModel): encoder_enable_byte_ngrams: bool = False encoder_ngram_to_size_str: str | None = None downsampling_by_pooling: str | None = None - + # Architecture and dimensions dim_token: int | None = None share_encoder_decoder_emb: bool = True weight_tying: bool = False - + # Attention configuration local_attention_window_len: int | None = None use_rope: bool = True - + # Performance optimization sequence_parallel: bool = False loss_parallel: bool = False fuse_sequence_parallel: bool = False use_fsdp: bool = True - + # Parameter mixing pm_size: int = 0 - + # Special token config eos_id: int = EOS_ID @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: - if ( - self.encoder_hash_byte_group_size is not None - and type(self.encoder_hash_byte_group_size) == str - ): + if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: self.encoder_hash_byte_group_size = [ - int(x) - for x in self.encoder_hash_byte_group_size.split(",") - if len(x) > 0 + int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 ] return self - diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 645906a1f2fd..7c645e8b18f6 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -15,7 +15,6 @@ """BLT (Byte Latent Transformer) model configuration""" from enum import Enum -from typing import Any, Optional from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -51,7 +50,7 @@ class BLTConfig(PretrainedConfig): Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. max_seqlen (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. - + # Main architecture dimensions dim (`int`, *optional*, defaults to 512): Main dimension of the model. @@ -63,7 +62,7 @@ class BLTConfig(PretrainedConfig): Dimension of each attention head. If not specified, computed as dim // n_heads. n_kv_heads (`int`, *optional*): Number of key-value heads for grouped query attention. If not specified, defaults to n_heads. - + # Component-specific dimensions dim_global (`int`, *optional*, defaults to 512): Dimension of the global transformer component. @@ -85,7 +84,7 @@ class BLTConfig(PretrainedConfig): Number of attention heads in the local encoder. n_kv_heads_global (`int`, *optional*): Number of key-value heads in the global transformer. - + # Transformer configuration norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the layer normalization layers. @@ -95,13 +94,13 @@ class BLTConfig(PretrainedConfig): Multiplier for the feedforward network dimension. multiple_of (`int`, *optional*, defaults to 256): Make feedforward network dimension multiple of this value. - + # Positional encoding rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): Whether to use fp32 in RoPE outer product computation. - + # Attention configuration attn_impl (`str`, *optional*, defaults to "sdpa"): Attention implementation to use ("sdpa" or "flex_attention"). @@ -111,19 +110,19 @@ class BLTConfig(PretrainedConfig): Window length for local attention. use_rope (`bool`, *optional*, defaults to True): Whether to use rotary position embeddings. - + # Initialization init_base_std (`float`, *optional*): Base standard deviation for weight initialization. init_std_factor (`str`, *optional*, defaults to "disabled"): Factor for adjusting initialization standard deviation. - + # Embedding dimensions dim_token_emb (`int`, *optional*): Token embedding dimension. dim_token (`int`, *optional*): Token dimension. - + # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): Whether to perform patching during forward pass. @@ -147,7 +146,7 @@ class BLTConfig(PretrainedConfig): Maximum length of patches. entropy_model_checkpoint_dir (`str`, *optional*): Directory containing entropy model checkpoint. - + # Cross attention configurations cross_attn_encoder (`bool`, *optional*, defaults to False): Whether to use cross attention in encoder. @@ -169,7 +168,7 @@ class BLTConfig(PretrainedConfig): Whether to use flexible attention for cross attention. cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): Whether to initialize cross attention by pooling. - + # Encoder configurations use_local_encoder_transformer (`bool`, *optional*, defaults to False): Whether to use transformer in local encoder. @@ -187,13 +186,13 @@ class BLTConfig(PretrainedConfig): String defining n-gram sizes. downsampling_by_pooling (`str`, *optional*): Type of pooling for downsampling. - + # Model behavior share_encoder_decoder_emb (`bool`, *optional*, defaults to True): Whether to share encoder and decoder embeddings. weight_tying (`bool`, *optional*, defaults to False): Whether to tie input and output embeddings. - + # Performance optimization sequence_parallel (`bool`, *optional*, defaults to False): Whether to use sequence parallelism. @@ -203,11 +202,11 @@ class BLTConfig(PretrainedConfig): Whether to fuse sequence parallel operations. use_fsdp (`bool`, *optional*, defaults to True): Whether to use fully sharded data parallel. - + # Parameter mixing pm_size (`int`, *optional*, defaults to 0): Parameter mixing size. - + # Special tokens bos_token_id (`int`, *optional*, defaults to 1): The id of the "beginning-of-sequence" token. @@ -215,7 +214,7 @@ class BLTConfig(PretrainedConfig): The id of the "end-of-sequence" token. pad_token_id (`int`, *optional*, defaults to -1): The id of the padding token. - + # Patcher/Entropy model configuration patcher_vocab_size (`int`, *optional*, defaults to 256): Vocabulary size for the entropy model used in patching. @@ -282,14 +281,12 @@ def __init__( self, vocab_size=256, max_seqlen=1024, - # Main architecture dimensions dim=512, n_layers=8, n_heads=8, head_dim=None, n_kv_heads=None, - # Component-specific dimensions dim_global=512, dim_local_decoder=512, @@ -301,31 +298,25 @@ def __init__( n_heads_local_decoder=8, n_heads_local_encoder=8, n_kv_heads_global=None, - # Transformer configuration norm_eps=1e-5, dropout=0.0, ffn_dim_multiplier=1.0, multiple_of=256, - # Positional encoding rope_theta=10000.0, rope_use_fp32_in_outer_product=False, - # Attention configuration attn_impl="sdpa", attn_bias_type="causal", local_attention_window_len=None, use_rope=True, - # Initialization init_base_std=None, init_std_factor="disabled", - # Embedding dimensions dim_token_emb=None, dim_token=None, - # Patching configuration patch_in_forward=False, realtime_patching=True, @@ -338,7 +329,6 @@ def __init__( patching_device="cuda", max_patch_length=None, entropy_model_checkpoint_dir=None, - # Cross attention configurations cross_attn_encoder=False, cross_attn_decoder=False, @@ -350,7 +340,6 @@ def __init__( cross_attn_all_layers_encoder=False, cross_attn_use_flex_attention=True, cross_attn_init_by_pooling=False, - # Encoder configurations use_local_encoder_transformer=False, max_encoder_seq_length=None, @@ -360,25 +349,20 @@ def __init__( encoder_enable_byte_ngrams=False, encoder_ngram_to_size_str=None, downsampling_by_pooling=None, - # Model behavior share_encoder_decoder_emb=True, weight_tying=False, - # Performance optimization sequence_parallel=False, loss_parallel=False, fuse_sequence_parallel=False, use_fsdp=True, - # Parameter mixing pm_size=0, - # Special tokens bos_token_id=1, eos_token_id=2, pad_token_id=-1, - # Patcher/Entropy model configuration patcher_vocab_size=256, patcher_dim=512, @@ -402,21 +386,20 @@ def __init__( patcher_weight_tying=False, patcher_bos_token_id=1, patcher_eos_token_id=2, - # Inherited **kwargs, ): # Basic model configuration self.vocab_size = vocab_size self.max_seqlen = max_seqlen - + # Main architecture dimensions self.dim = dim self.n_layers = n_layers self.n_heads = n_heads self.head_dim = head_dim self.n_kv_heads = n_kv_heads - + # Component-specific dimensions self.dim_global = dim_global self.dim_local_decoder = dim_local_decoder @@ -428,31 +411,31 @@ def __init__( self.n_heads_local_decoder = n_heads_local_decoder self.n_heads_local_encoder = n_heads_local_encoder self.n_kv_heads_global = n_kv_heads_global - + # Transformer configuration self.norm_eps = norm_eps self.dropout = dropout self.ffn_dim_multiplier = ffn_dim_multiplier self.multiple_of = multiple_of - + # Positional encoding self.rope_theta = rope_theta self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product - + # Attention configuration self.attn_impl = attn_impl self.attn_bias_type = attn_bias_type self.local_attention_window_len = local_attention_window_len self.use_rope = use_rope - + # Initialization self.init_base_std = init_base_std self.init_std_factor = InitStdFactor(init_std_factor) - + # Embedding dimensions self.dim_token_emb = dim_token_emb self.dim_token = dim_token - + # Patching configuration self.patch_in_forward = patch_in_forward self.realtime_patching = realtime_patching @@ -465,7 +448,7 @@ def __init__( self.patching_device = patching_device self.max_patch_length = max_patch_length self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir - + # Cross attention configurations self.cross_attn_encoder = cross_attn_encoder self.cross_attn_decoder = cross_attn_decoder @@ -477,7 +460,7 @@ def __init__( self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder self.cross_attn_use_flex_attention = cross_attn_use_flex_attention self.cross_attn_init_by_pooling = cross_attn_init_by_pooling - + # Encoder configurations self.use_local_encoder_transformer = use_local_encoder_transformer self.max_encoder_seq_length = max_encoder_seq_length @@ -487,20 +470,20 @@ def __init__( self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams self.encoder_ngram_to_size_str = encoder_ngram_to_size_str self.downsampling_by_pooling = downsampling_by_pooling - + # Model behavior self.share_encoder_decoder_emb = share_encoder_decoder_emb self.weight_tying = weight_tying - + # Performance optimization self.sequence_parallel = sequence_parallel self.loss_parallel = loss_parallel self.fuse_sequence_parallel = fuse_sequence_parallel self.use_fsdp = use_fsdp - + # Parameter mixing self.pm_size = pm_size - + # Patcher/Entropy model configuration self.patcher_vocab_size = patcher_vocab_size self.patcher_dim = patcher_dim @@ -526,14 +509,9 @@ def __init__( self.patcher_eos_token_id = patcher_eos_token_id # Handle hash byte group size validation - if ( - self.encoder_hash_byte_group_size is not None - and type(self.encoder_hash_byte_group_size) == str - ): + if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: self.encoder_hash_byte_group_size = [ - int(x) - for x in self.encoder_hash_byte_group_size.split(",") - if len(x) > 0 + int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 ] super().__init__( @@ -581,12 +559,7 @@ def global_dim_patch_emb(self): patch_size = self.patch_size if self.patch_size is not None else 8 return dim_token_emb * patch_size else: - return dim_token_emb * sum( - [ - pooling in self.downsampling_by_pooling - for pooling in ["avg", "min", "max"] - ] - ) + return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]]) @property def decoder_dim_token_emb(self): @@ -601,10 +574,10 @@ def decoder_dim_token_emb(self): def get_init_std_factor(self, depth: int) -> float: """ Calculate the initialization standard deviation scaling factor for a given layer depth. - + Args: depth: Current layer depth (0-indexed) - + Returns: Scaling factor to divide the base initialization std by """ @@ -614,4 +587,4 @@ def get_init_std_factor(self, depth: int) -> float: return 1.0 -__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] \ No newline at end of file +__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] diff --git a/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py b/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py new file mode 100644 index 000000000000..dad247b19c62 --- /dev/null +++ b/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py @@ -0,0 +1,540 @@ +import argparse +import json +import logging +import os +from typing import Dict, Any, Optional + +import torch +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file, save_file + +from transformers.utils import logging as transformers_logging + +logger = transformers_logging.get_logger(__name__) +transformers_logging.set_verbosity_info() + +# For standalone execution, we'll skip the model validation to avoid import issues +# The script will create the unified config and weights files without testing model instantiation +ENABLE_MODEL_VALIDATION = False + +import sys +import os + +from transformers.models.blt_wip.modeling_blt_wip import BLTModel +from transformers.models.blt_wip.configuration_blt import BLTConfig + + +ENABLE_MODEL_VALIDATION = True + +def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: + """ + Download all necessary files from HuggingFace Hub. + + Args: + model_id: HuggingFace model ID (e.g., "facebook/blt-1b") + cache_dir: Optional cache directory + + Returns: + Dictionary with paths to downloaded files + """ + logger.info(f"Downloading model files from {model_id}...") + + try: + # Download main config + config_path = hf_hub_download( + repo_id=model_id, + filename="config.json", + cache_dir=cache_dir + ) + + # Download main model weights + weights_path = hf_hub_download( + repo_id=model_id, + filename="model.safetensors", + cache_dir=cache_dir + ) + + # Download entropy model params + entropy_params_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/params.json", + cache_dir=cache_dir + ) + + # Download entropy model weights + entropy_weights_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/consolidated.pth", + cache_dir=cache_dir + ) + + return { + "config": config_path, + "weights": weights_path, + "entropy_params": entropy_params_path, + "entropy_weights": entropy_weights_path + } + + except Exception as e: + logger.error(f"Failed to download files from {model_id}: {e}") + raise + + +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: + """ + Merge main configuration with entropy model parameters. + + Args: + config_path: Path to main config.json + entropy_params_path: Path to entropy_model/params.json + + Returns: + Merged configuration dictionary + """ + logger.info("Merging configurations...") + + # Load main configuration + with open(config_path, 'r') as f: + main_config = json.load(f) + + # Load entropy model parameters + with open(entropy_params_path, 'r') as f: + entropy_data = json.load(f) + + # Extract entropy model and patcher parameters + entropy_model_params = entropy_data.get("entropy_model", {}) + patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) + + # Create unified configuration + unified_config = main_config.copy() + + # Ensure required main model parameters are present with correct types + # Sometimes the original config may have different key names + if "vocab_size" not in unified_config: + unified_config["vocab_size"] = int(main_config.get("vocab_size", 256)) + if "dim" not in unified_config: + unified_config["dim"] = int(main_config.get("dim", main_config.get("hidden_size", main_config.get("d_model", 512)))) + if "n_layers" not in unified_config: + unified_config["n_layers"] = int(main_config.get("n_layers", main_config.get("num_layers", main_config.get("num_hidden_layers", 8)))) + if "n_heads" not in unified_config: + unified_config["n_heads"] = int(main_config.get("n_heads", main_config.get("num_attention_heads", main_config.get("num_heads", 8)))) + if "max_seqlen" not in unified_config: + unified_config["max_seqlen"] = int(main_config.get("max_seqlen", main_config.get("max_position_embeddings", main_config.get("seq_length", 1024)))) + + # Ensure other integer parameters are properly typed + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: + if key in unified_config and not isinstance(unified_config[key], int): + unified_config[key] = int(unified_config[key]) + + # Convert all patch_size values to integers to avoid float/int type errors + patch_size = patcher_args.get("patch_size", 8) + if isinstance(patch_size, float): + patch_size = int(patch_size) + + # Add patching configuration + unified_config.update({ + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + + # Patcher arguments + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + + # Entropy model (patcher) architecture parameters + "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "patcher_dim": int(entropy_model_params.get("dim", 512)), + "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), + "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), + "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None, + "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, + "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "patcher_dropout": entropy_model_params.get("dropout", 0.0), + "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None, + "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), + "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), + "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False), + "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "patcher_init_base_std": entropy_model_params.get("init_base_std"), + "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), + "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), + "patcher_weight_tying": entropy_model_params.get("weight_tying", False), + "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), + "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), + }) + + logger.info(f"Merged configuration with {len(unified_config)} parameters") + return unified_config + + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: + """ + Merge main model weights with entropy model weights. + + Args: + weights_path: Path to main model.safetensors + entropy_weights_path: Path to entropy_model/consolidated.pth + + Returns: + Merged state dictionary + """ + logger.info("Merging model weights...") + + # Load main model weights + main_weights = load_file(weights_path) + logger.info(f"Loaded main model weights: {len(main_weights)} tensors") + + # Load entropy model weights + entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True) + + # Handle nested entropy model structure + if 'model' in entropy_weights: + entropy_weights = entropy_weights['model'] + elif 'state_dict' in entropy_weights: + entropy_weights = entropy_weights['state_dict'] + + logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") + + # Create unified state dict + unified_weights = main_weights.copy() + + # Add entropy model weights with "patcher." prefix + for key, tensor in entropy_weights.items(): + patcher_key = f"patcher.{key}" + unified_weights[patcher_key] = tensor + + logger.info(f"Merged weights: {len(unified_weights)} tensors total") + return unified_weights + + +def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): + """ + Create tokenizer configuration file. + + Args: + output_dir: Output directory + config: Model configuration + """ + logger.info("Creating tokenizer configuration...") + + tokenizer_config = { + "tokenizer_class": "BltTokenizer", + "vocab_size": config.get("vocab_size", 256), + "model_max_length": config.get("max_seqlen", 1024), + "add_bos_token": True, + "add_eos_token": True, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") + with open(tokenizer_path, 'w') as f: + json.dump(tokenizer_config, f, indent=2) + + logger.info(f"Tokenizer config saved to {tokenizer_path}") + + +def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): + """ + Validate the unified model configuration and weights. + + Args: + config: Unified configuration + weights: Unified weights + """ + logger.info("Validating unified model...") + + # Check required configuration keys + required_keys = [ + "vocab_size", "dim", "n_layers", "n_heads", + "patch_in_forward", "patcher_vocab_size", "patcher_dim" + ] + + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + logger.warning(f"Missing configuration keys: {missing_keys}") + + # Check for patcher weights + patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] + if not patcher_weights: + logger.warning("No patcher weights found in unified weights") + else: + logger.info(f"Found {len(patcher_weights)} patcher weight tensors") + + # Check for main model weights + main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] + logger.info(f"Found {len(main_weights)} main model weight tensors") + + # Try to create the model with the configuration (if imports are available) + if ENABLE_MODEL_VALIDATION and BLTConfig is not None and BLTModel is not None: + try: + logger.info("Testing model instantiation...") + + # Debug: Print config keys to help diagnose issues + logger.debug(f"Config keys: {list(config.keys())}") + logger.debug(f"Config vocab_size: {config.get('vocab_size')} (type: {type(config.get('vocab_size'))})") + logger.debug(f"Config dim: {config.get('dim')} (type: {type(config.get('dim'))})") + + blt_config = BLTConfig(**config) + model = BLTModel(blt_config) + logger.info("✓ Model instantiation successful") + + # Try to load the weights + logger.info("Testing weight loading...") + try: + missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) + if missing_keys: + logger.warning(f"Missing keys during weight loading: {missing_keys[:5]}...") # Show first 5 + if unexpected_keys: + logger.warning(f"Unexpected keys during weight loading: {unexpected_keys[:5]}...") # Show first 5 + logger.info("✓ Weight loading successful") + except Exception as weight_error: + logger.warning(f"Weight loading failed: {weight_error}") + logger.info("Model instantiation successful, but weight loading had issues") + + except Exception as e: + logger.error(f"Model validation failed: {e}") + logger.debug(f"Full error details:", exc_info=True) + logger.warning("Model may not be compatible with modeling_blt_wip.py") + logger.info("You can still use the converted files and test manually") + else: + logger.info("Skipping model instantiation test (BLT classes not available)") + logger.info("You can test the model manually after conversion") + + logger.info("Model validation completed") + + +def convert_hf_blt_to_unified( + model_id: str, + output_dir: str, + config_name: str = "config.json", + weights_name: str = "pytorch_model.bin", + safe_serialization: bool = True, + cache_dir: Optional[str] = None, + validate: bool = True, +) -> None: + """ + Convert BLT model from HuggingFace Hub format to unified format. + + Args: + model_id: HuggingFace model ID (e.g., "facebook/blt-1b") + output_dir: Output directory for unified model + config_name: Name for unified config file + weights_name: Name for unified weights file + safe_serialization: Whether to use safetensors format + cache_dir: Cache directory for downloads + validate: Whether to validate the unified model + """ + logger.info(f"Converting {model_id} to unified format...") + + # Download model files + file_paths = download_model_files(model_id, cache_dir) + + # Merge configurations + unified_config = merge_configurations( + file_paths["config"], + file_paths["entropy_params"] + ) + + # Merge weights + unified_weights = merge_weights( + file_paths["weights"], + file_paths["entropy_weights"] + ) + + # Validate if requested + if validate: + validate_unified_model(unified_config, unified_weights) + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Save unified configuration + config_path = os.path.join(output_dir, config_name) + with open(config_path, 'w') as f: + json.dump(unified_config, f, indent=2) + logger.info(f"Unified config saved to {config_path}") + + # Save unified weights + if safe_serialization and weights_name.endswith('.bin'): + weights_name = weights_name.replace('.bin', '.safetensors') + elif not safe_serialization and weights_name.endswith('.safetensors'): + weights_name = weights_name.replace('.safetensors', '.bin') + + weights_path = os.path.join(output_dir, weights_name) + if safe_serialization: + save_file(unified_weights, weights_path) + else: + torch.save(unified_weights, weights_path) + logger.info(f"Unified weights saved to {weights_path}") + + # Create tokenizer config + create_tokenizer_config(output_dir, unified_config) + + # Create README + readme_path = os.path.join(output_dir, "README.md") + with open(readme_path, 'w') as f: + f.write(f"""# Unified BLT Model + +This model was converted from {model_id} to unified format compatible with modeling_blt_wip.py. + +## Files + +- `{config_name}`: Unified configuration (main config + entropy model params) +- `{weights_name}`: Unified weights (main model + entropy model weights with "patcher." prefix) +- `tokenizer_config.json`: Tokenizer configuration + +## Usage + +```python +import torch +import json +from modeling_blt_wip import BLTModel, BLTConfig + +# Load configuration +with open('{config_name}', 'r') as f: + config_dict = json.load(f) + +config = BLTConfig(**config_dict) + +# Load model +model = BLTModel(config) + +# Load weights +if '{weights_name}'.endswith('.safetensors'): + from safetensors.torch import load_file + state_dict = load_file('{weights_name}') +else: + state_dict = torch.load('{weights_name}', map_location='cpu') + +model.load_state_dict(state_dict, strict=False) +``` + +## Original Model + +Converted from: {model_id} +""") + + logger.info(f"Conversion completed! Unified model saved to: {output_dir}") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert BLT models from HuggingFace Hub format to unified format", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Convert facebook/blt-1b to unified format + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-1b \\ + --output_dir ./unified_blt_1b + + # Convert with custom file names + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-7b \\ + --output_dir ./unified_blt_7b \\ + --config_name unified_config.json \\ + --weights_name unified_model.safetensors + + # Convert without validation + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-1b \\ + --output_dir ./my_blt \\ + --no_validate + """ + ) + + # Required arguments (with defaults for debugging) + parser.add_argument( + "--model_id", + type=str, + default="facebook/blt-1b", + help="HuggingFace model ID (e.g., facebook/blt-1b)" + ) + parser.add_argument( + "--output_dir", + type=str, + default="./unified_blt_debug", + help="Output directory for unified model" + ) + + # Optional arguments + parser.add_argument( + "--config_name", + type=str, + default="config.json", + help="Name for unified config file (default: config.json)" + ) + parser.add_argument( + "--weights_name", + type=str, + default="pytorch_model.bin", + help="Name for unified weights file (default: pytorch_model.bin)" + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + default=True, + help="Use safetensors format for weights (default: True)" + ) + parser.add_argument( + "--no_safe_serialization", + dest="safe_serialization", + action="store_false", + help="Use .bin format instead of safetensors" + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="Cache directory for downloads" + ) + parser.add_argument( + "--no_validate", + dest="validate", + action="store_false", + default=True, + help="Skip model validation" + ) + parser.add_argument( + "--debug", + action="store_true", + default=True, # Enable debug by default for easier debugging + help="Enable debug logging" + ) + + args = parser.parse_args() + + # Setup logging + if args.debug: + transformers_logging.set_verbosity_debug() + logging.basicConfig(level=logging.DEBUG) + + # Run conversion + try: + convert_hf_blt_to_unified( + model_id=args.model_id, + output_dir=args.output_dir, + config_name=args.config_name, + weights_name=args.weights_name, + safe_serialization=args.safe_serialization, + cache_dir=args.cache_dir, + validate=args.validate, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 71e390077819..3cd5edccbc38 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -1,22 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from enum import Enum -from typing import Any, List, Optional, Tuple, Union - -import torch -from pydantic import model_validator -from torch import nn -from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention -import json import logging +import os +from typing import List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -import os -from contextlib import nullcontext +from ...modeling_utils import PreTrainedModel +from .configuration_blt import ( + BLTConfig, + PatchingModeEnum, +) SEP = " " BOS_ID: int = 1 @@ -32,15 +30,6 @@ logger = logging.getLogger() -from .configuration_blt import ( - BLTConfig, - PatchingModeEnum, - InitStdFactor, -) - -from ...modeling_utils import PreTrainedModel -from ...utils import logging as transformers_logging - flex_attention_comp = flex_attention @@ -72,9 +61,8 @@ def create_causal_mask( elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) + raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -153,9 +141,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int 2, 2, ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" - shape = [ - d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) - ] + [2, 2] + shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] return freqs_cis.view(*shape) @@ -167,9 +153,7 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - freqs_cis = reshape_for_broadcast( - freqs_cis, xq_, seq_dim - ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 xq_out = (xq_ * freqs_cis).sum(5).flatten(3) xk_out = (xk_ * freqs_cis).sum(5).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -214,9 +198,7 @@ def reset_parameters(self): rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, ) - def forward( - self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None - ): + def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): """ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions Args: @@ -307,12 +289,12 @@ def forward( if attn_impl == "flex_attention": assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D elif attn_impl == "sdpa": - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) assert mask is None or isinstance(mask, (str, torch.Tensor)) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None @@ -325,10 +307,8 @@ def forward( ) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D else: - raise NotImplementedError( - f"Attention implementation {attn_impl} not supported" - ) - + raise NotImplementedError(f"Attention implementation {attn_impl} not supported") + output_reshaped = output.reshape(output_shape) output = self.wo(output_reshaped) @@ -431,18 +411,16 @@ def __init__(self, args): super().__init__() # Extract parameters from dictionary - dim = args['dim'] - n_heads = args['n_heads'] - head_dim = args['head_dim'] - n_kv_heads = args['n_kv_heads'] - rope_theta = args['rope_theta'] - multiple_of = args['multiple_of'] - ffn_dim_multiplier = args['ffn_dim_multiplier'] - norm_eps = args['norm_eps'] - - assert (head_dim is not None) or ( - n_heads is not None - ), "Should specify at least head_dim or n_heads" + dim = args["dim"] + n_heads = args["n_heads"] + head_dim = args["head_dim"] + n_kv_heads = args["n_kv_heads"] + rope_theta = args["rope_theta"] + multiple_of = args["multiple_of"] + ffn_dim_multiplier = args["ffn_dim_multiplier"] + norm_eps = args["norm_eps"] + + assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" self.head_dim = head_dim or dim // n_heads self.n_heads = n_heads or dim // head_dim self.n_kv_heads = n_kv_heads or self.n_heads @@ -540,9 +518,8 @@ def rolling_polynomial_hash(t, hash_func_nb: int = 0): prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) return torch.sum(t * prime_powers, dim=-1) -def byte_group_hash_function( - x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 -): + +def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): """ Returns a hash of the input x and maps it to a value in the range [0, max_hash]. @@ -563,9 +540,7 @@ def byte_group_hash_function( return hash_values_range -def create_patch_mask_from_ids( - patch_ids, num_patches, window=None, patches_as_queries=False -): +def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): """ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) is True if the patch id at position (i, j) is less than or equal to k. @@ -642,9 +617,7 @@ def patch_mask(b, h, q_idx, kv_idx): ) return block_mask else: - return torch.where( - cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) - ).unsqueeze( + return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( 1 ) # [bs, 1, q_len, kv_len] @@ -750,7 +723,7 @@ def get_blt_input( class LocalModelBase(nn.Module): def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() - + # Store config for later use self.config = config @@ -786,25 +759,23 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): self.eos_id = config.eos_token_id self.boe_id = BOE_ID - + # Initialize cross attention layers as None (will be set by subclasses if needed) self.cross_attn_layers = None # Create parameter dict for BLTTransformerLayers layer_params = { - 'dim': self.dim, - 'n_heads': self.n_heads, - 'head_dim': config.head_dim, - 'n_kv_heads': getattr(config, 'n_kv_heads', None), - 'rope_theta': config.rope_theta, - 'multiple_of': getattr(config, 'multiple_of', 256), - 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), - 'norm_eps': config.norm_eps, + "dim": self.dim, + "n_heads": self.n_heads, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, } - self.layers = nn.ModuleList( - [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)] - ) + self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) if not self.use_rope: self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length @@ -834,14 +805,12 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): self.patch_embedding_projection = self._create_patch_projection(config) def _should_create_patch_projection(self, config: BLTConfig): - dimension_mismatch = ( - self.dim_patch_emb is not None and self.dim_patch_emb != self.dim - ) + dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim # Check cross attention conditions - cross_attn_conditions = ( - config.cross_attn_encoder and config.cross_attn_init_by_pooling - ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling) + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) return dimension_mismatch or cross_attn_conditions @@ -954,9 +923,7 @@ def __init__(self, config: BLTConfig): def apply_embedding(self, tokens, embeds): if embeds is not None: - assert ( - self.expects_hash_embeddings - ), "Not expecting embeddings to be passed." + assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." return embeds else: return self.tok_embeddings(tokens) @@ -992,17 +959,13 @@ def forward( for i, layer in enumerate(self.layers): h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer - if self.cross_attn_encoder and ( - i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder - ): + if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): # apply pooling and project if self.cross_attn_init_by_pooling and patch_embeds is None: patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape( - bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim - ) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) layer_idx = i if self.cross_attn_all_layers_encoder else 0 patch_embeds_cross = self.cross_attn_layers[layer_idx]( @@ -1015,8 +978,6 @@ def forward( h_residual = patch_embeds if self.cross_attn_encoder else None return (h, h_residual), cache - - def patch_reduce(self, h, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch @@ -1032,9 +993,7 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - reduced_embs = torch.zeros( - (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device - ) + reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) reduced_embs = reduced_embs.scatter_reduce( src=h, dim=1, @@ -1107,9 +1066,7 @@ def forward( assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape( - bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim - ) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) if patch_embeds is not None and not self.cross_attn_decoder: h = h + patch_embeds @@ -1118,9 +1075,7 @@ def forward( h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - if self.cross_attn_decoder and ( - i == 0 or self.cross_attn_all_layers_decoder - ): + if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): # Use cross attention to extract info from patch_embeds into h h_cross = self.cross_attn_layers[i]( x=h, @@ -1211,9 +1166,9 @@ def forward( xk = repeat_kv(xk, self.heads_per_group, dim=2) xv = repeat_kv(xv, self.heads_per_group, dim=2) - # assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) - #output = flex_attention_comp(xq, xk, xv, block_mask=mask) + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + # output = flex_attention_comp(xq, xk, xv, block_mask=mask) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask if isinstance(mask, torch.Tensor) else None mask = mask.to(dtype=xq.dtype).to(xq.device) @@ -1271,11 +1226,11 @@ def init_weights(self, base_std: float, factor: float = 1.0): class GlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - + # Store config for later use self.config = config - - self.dim = config.dim + + self.dim = config.dim_global self.init_base_std = config.init_base_std self.attn_impl = config.attn_impl self.attn_bias_type = config.attn_bias_type @@ -1283,38 +1238,38 @@ def __init__(self, config): self.max_seqlen = config.max_seqlen self.rope_embeddings = RotaryEmbedding( theta=config.rope_theta, - head_dim=config.head_dim or config.dim // config.n_heads, + head_dim=config.head_dim or config.dim_global // config.n_heads_global, max_seqlen=config.max_seqlen, rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, ) # Handle both eos_id and eos_token_id for compatibility - self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) # Create parameter dict for BLTTransformerLayers layer_params = { - 'dim': self.dim, - 'n_heads': config.n_heads, - 'head_dim': config.head_dim, - 'n_kv_heads': getattr(config, 'n_kv_heads', None), - 'rope_theta': config.rope_theta, - 'multiple_of': getattr(config, 'multiple_of', 256), - 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), - 'norm_eps': config.norm_eps, + "dim": self.dim, + "n_heads": config.n_heads_global, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads_global", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, } self.layers = nn.ModuleList() - for _ in range(config.n_layers): + for _ in range(config.n_layers_global): self.layers.append(BLTTransformerLayer(layer_params)) - + # GlobalTransformer specific attributes self.dropout = config.dropout - self.dim_token_emb = config.dim_token_emb + self.dim_token_emb = config.global_dim_patch_emb self.token_embedding_projection = None - if config.dim_token_emb is not None and config.dim_token_emb != self.dim: + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: self.token_embedding_projection = nn.Linear( - config.dim_token_emb, - config.dim, + config.global_dim_patch_emb, + config.dim_global, bias=False, ) @@ -1351,7 +1306,7 @@ def forward( for i, layer in enumerate(self.layers): h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) - + return h, cache def init_weights(self): @@ -1359,7 +1314,7 @@ def init_weights(self): for depth, layer in enumerate(self.layers): factor = self.config.get_init_std_factor(depth) layer.init_weights(self.init_base_std, factor) - + # GlobalTransformer specific initialization std = self.dim_token_emb ** (-0.5) if self.token_embedding_projection is not None: @@ -1371,6 +1326,7 @@ def init_weights(self): b=3 * std, ) + def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, @@ -1419,10 +1375,10 @@ class BLTPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained BLT models. - + This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. - + Args: config ([`BLTConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the @@ -1455,35 +1411,19 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - # Store config reference self.config = config - - # Create main components - they will read their parameters from config self.local_encoder = LocalEncoder(config) - - # Create global-specific config by copying config and overriding dimensions - global_config = type(config)(**config.to_dict()) - global_config.dim = config.dim_global - global_config.n_layers = config.n_layers_global - global_config.n_heads = config.n_heads_global - global_config.n_kv_heads = config.n_kv_heads_global - global_config.dim_token_emb = config.global_dim_patch_emb - - self.global_transformer = GlobalTransformer(global_config) + self.global_transformer = GlobalTransformer(config) self.local_decoder = LocalDecoder(config) - - # Initialize hash embeddings + self.encoder_hash_tok_embedding = init_hash_embeddings( config, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - # Initialize patcher if needed if config.patch_in_forward: - # Create patcher with config self.patcher = BLTPatcher(config) - # Set patcher to eval mode and disable gradients self.patcher.eval() for param in self.patcher.parameters(): param.requires_grad = False @@ -1512,16 +1452,16 @@ def init_weights(self): def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """ Convert patch lengths to patch IDs for each token position. - + For each token position in the sequence, determines which patch it belongs to. - + Args: patch_lengths: [batch_size, num_patches] - length of each patch seq_len: total sequence length - + Returns: patch_ids: [batch_size, seq_len] - patch index for each token position - + Example: patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 seq_len = 10 @@ -1529,55 +1469,56 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 """ batch_size, num_patches = patch_lengths.shape - + # Create patch start positions: [0, 3, 5, 9] for the example above - patch_starts = torch.cat([ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total - ], dim=-1) - + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total + ], + dim=-1, + ) + # For each token position, find which patch it belongs to # by finding the rightmost patch start that's <= the position token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] - + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) - + # Count how many patch starts are <= each position, then subtract 1 to get patch index patch_ids = position_ge_patch_start.sum(dim=-1) - 1 - + return patch_ids def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: """ Create decoder patch IDs by skipping the first encoder patch. - - The decoder starts after the first patch (which contains BOE tokens), + + The decoder starts after the first patch (which contains BOE tokens), so we need to map decoder positions to the remaining patches. - + Args: - patch_lengths: [batch_size, num_patches] from encoder + patch_lengths: [batch_size, num_patches] from encoder nb_boe: number of beginning-of-example tokens in first patch seq_len: decoder sequence length - + Returns: decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices """ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) decoder_patch_lengths = patch_lengths[:, 1:] - + # Create patch IDs for the decoder sequence using the remaining patches return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) - - def forward( self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None, ): - # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings # are no longer used in the final BLT model bs, N = tokens.shape # Batch size and sequence length @@ -1597,7 +1538,7 @@ def forward( # assert ( # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward # ), "Patch in forward not enabled and no patch_lengths passed." - + # PATCHER MODEL DEFINED if self.config.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( @@ -1618,7 +1559,7 @@ def forward( patch_lengths = torch.ones( (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device ) - + # Apply any processing to patch lengths if self.config.max_patch_length is not None: # TODO: avoid going back to a list here. @@ -1633,13 +1574,9 @@ def forward( ) assert not check_non_zero_after_zero(patch_lengths) # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) + last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] + patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] else: if nb_boe > 0: patch_lengths[:, 0] += nb_boe @@ -1647,12 +1584,10 @@ def forward( assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths - patch_ids = self._patch_ids_from_lengths( - patch_lengths, local_encoder_tokens.shape[-1] + patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) + assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( + f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" ) - assert torch.max(patch_ids) + 1 <= torch.max( - (patch_lengths != 0).sum(dim=-1) - ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" cross_attn_mask_enc = None # Cross-attention encoder @@ -1708,21 +1643,15 @@ def forward( dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] # Generate decoder patch IDs - decoder_patch_ids = self._decoder_patch_ids_from_lengths( - patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + decoder_patch_ids = self._decoder_patch_ids_from_lengths(patch_lengths, nb_boe, local_decoder_tokens.shape[-1]) + assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( + f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" ) - assert ( - torch.max(decoder_patch_ids) + 1 <= h.shape[1] - ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" - assert ( - decoder_patch_ids.shape[1] == dec_embeds.shape[1] - ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" # Cross-attention decoder if not self.config.cross_attn_decoder: - h = torch.gather( - h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - ) + h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) cross_attn_mask_dec = None assert local_decoder_tokens.shape == h.shape[:-1] else: @@ -1749,77 +1678,66 @@ def forward( class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): super().__init__(config) - - # Store config reference for later use - self.config = config - - # Extract patcher parameters from BLTConfig - self.dim = config.patcher_dim - self.init_base_std = config.patcher_init_base_std - self.attn_impl = config.patcher_attn_impl - self.attn_bias_type = config.patcher_attn_bias_type - self.init_std_factor = config.patcher_init_std_factor - self.max_seqlen = config.patcher_max_seqlen - n_layers = config.patcher_n_layers - n_heads = config.patcher_n_heads - head_dim = config.patcher_head_dim - rope_theta = config.patcher_rope_theta - rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product - norm_eps = config.patcher_norm_eps - vocab_size = config.patcher_vocab_size - weight_tying = config.patcher_weight_tying - sliding_window = config.patcher_sliding_window - eos_token_id = config.patcher_eos_token_id - + self.rope_embeddings = RotaryEmbedding( - theta=rope_theta, - head_dim=head_dim or self.dim // n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, + theta=config.patcher_rope_theta, + head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, + max_seqlen=config.patcher_max_seqlen, + rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, ) # Handle both eos_id and eos_token_id for compatibility - self.eos_id = eos_token_id + self.eos_id = config.patcher_eos_token_id # Extract additional parameters for BLTTransformerLayer - n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) - multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) - ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) - - # Create a simple parameter dict for BLTTransformerLayer - layer_params = { - 'dim': self.dim, - 'n_heads': n_heads, - 'head_dim': head_dim, - 'n_kv_heads': n_kv_heads, - 'rope_theta': rope_theta, - 'multiple_of': multiple_of, - 'ffn_dim_multiplier': ffn_dim_multiplier, - 'norm_eps': norm_eps, - } + n_kv_heads = ( + getattr(config, "patcher_n_kv_heads", None) + if hasattr(config, "patcher_dim") + else getattr(config, "n_kv_heads", None) + ) + multiple_of = ( + getattr(config, "patcher_multiple_of", 256) + if hasattr(config, "patcher_dim") + else getattr(config, "multiple_of", 256) + ) + ffn_dim_multiplier = ( + getattr(config, "patcher_ffn_dim_multiplier", None) + if hasattr(config, "patcher_dim") + else getattr(config, "ffn_dim_multiplier", None) + ) self.layers = nn.ModuleList() - for _ in range(n_layers): - self.layers.append(BLTTransformerLayer(layer_params)) - + for _ in range(config.patcher_n_layers): + self.layers.append( + BLTTransformerLayer( + { + "dim": config.patcher_dim, + "n_heads": config.patcher_n_heads, + "head_dim": config.patcher_head_dim, + "n_kv_heads": n_kv_heads, + "rope_theta": config.patcher_rope_theta, + "multiple_of": multiple_of, + "ffn_dim_multiplier": ffn_dim_multiplier, + "norm_eps": config.patcher_norm_eps, + } + ) + ) + # LMTransformer specific attributes - self.weight_tying = weight_tying - self.sliding_window = sliding_window + self.weight_tying = config.patcher_weight_tying + self.sliding_window = config.patcher_sliding_window - assert vocab_size > 0 + assert config.patcher_vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) + self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) - self.norm = RMSNorm(self.dim, eps=norm_eps) + self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) self.output = nn.Linear( - self.dim, - vocab_size, + config.patcher_dim, + config.patcher_vocab_size, bias=False, ) - if self.weight_tying: - self.output.weight = self.tok_embeddings.weight - def forward( self, token_values: torch.Tensor, @@ -1837,45 +1755,41 @@ def forward( device: Optional[str] = None, enable_grad: bool = False, ): - attn_impl = self.attn_impl if attn_impl is None else attn_impl + attn_impl = self.config.patcher_attn_impl if attn_impl is None else attn_impl # Handle chunked processing for entropy calculation entropies = [] preds = [] - max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) + max_length = min(getattr(self, "max_length", 8192), self.config.patcher_max_seqlen) batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) - + for split in splits: pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros( - pad_size, dtype=split.dtype, device=split.device, requires_grad=False - ) + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) split = torch.cat((split, pad), dim=0) split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - + # Process chunk: embeddings -> layers -> output bsz, seqlen = split.shape h = self.tok_embeddings(split) chunk_mask = create_causal_mask( seqlen, attn_impl, - self.attn_bias_type, + self.config.patcher_attn_bias_type, sliding_window=self.sliding_window, tokens=split, eos_id=self.eos_id, ) freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) - + for i, layer in enumerate(self.layers): h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) - + pred = self.output(self.norm(h)) - pred = pred.reshape(-1, pred.shape[-1])[ - : split.numel() - pad_size, : - ] # [batch_size * seq_len, vocab] + pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] preds.append(pred) pred_entropies = self.entropy(pred) entropies.append(pred_entropies) @@ -1884,11 +1798,11 @@ def forward( concat_entropies = concat_entropies.reshape(token_values.shape) concat_preds = torch.cat(preds, dim=0) concat_preds = concat_preds.reshape(token_values.shape[0], -1) - + # Always compute patch lengths from concatenated entropies bs, seq_len = token_values.shape seq_len_next_tok = seq_len + 1 if include_next_token else seq_len - + # Find patch start IDs based on entropy if patch_size is not None: patch_start_ids = self.find_entropy_patch_start_ids( @@ -1899,49 +1813,36 @@ def forward( threshold_add=threshold_add, monotonicity=monotonicity, ) - patch_lengths = self.patch_lengths_from_start_ids( - patch_start_ids, seq_len_next_tok - ) + patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching - patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device - ) + patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) # Apply any processing to patch lengths if max_patch_length is not None: # TODO: avoid going back to a list here. - patch_lengths = [ - self.split_large_numbers(pl, max_patch_length) - for pl in patch_lengths.tolist() - ] + patch_lengths = [self.split_large_numbers(pl, max_patch_length) for pl in patch_lengths.tolist()] max_len = max([len(pl) for pl in patch_lengths]) patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=token_values.dtype, device=token_values.device - ) + patch_lengths = torch.tensor(patch_lengths, dtype=token_values.dtype, device=token_values.device) assert not check_non_zero_after_zero(patch_lengths) # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) + last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] - + patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] + return concat_entropies, patch_lengths, concat_preds def init_weights(self): """Initialize weights for the patcher model""" # Initialize RoPE embeddings self.rope_embeddings.reset_parameters() - + # Initialize norm layer self.norm.reset_parameters() - + # Initialize token embeddings - emb_std = self.dim ** (-0.5) + emb_std = self.patcher_dim ** (-0.5) nn.init.trunc_normal_( self.tok_embeddings.weight, mean=0.0, @@ -1949,12 +1850,12 @@ def init_weights(self): a=-3 * emb_std, b=3 * emb_std, ) - + # Initialize transformer layers for depth, layer in enumerate(self.layers): factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - + layer.init_weights(self.patcher_init_base_std, factor) + # Initialize output layer if not weight tied if not self.weight_tying: nn.init.trunc_normal_( @@ -1971,9 +1872,9 @@ def _init_weights(self, module): nn.init.trunc_normal_( module.weight, mean=0.0, - std=self.init_base_std or (self.dim ** (-0.5)), - a=-3 * (self.init_base_std or (self.dim ** (-0.5))), - b=3 * (self.init_base_std or (self.dim ** (-0.5))), + std=self.patcher_init_base_std or (self.patcher_dim ** (-0.5)), + a=-3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), + b=3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), ) if module.bias is not None: nn.init.zeros_(module.bias) @@ -1981,13 +1882,11 @@ def _init_weights(self, module): nn.init.trunc_normal_( module.weight, mean=0.0, - std=self.init_base_std or (self.dim ** (-0.5)), - a=-3 * (self.init_base_std or (self.dim ** (-0.5))), - b=3 * (self.init_base_std or (self.dim ** (-0.5))), + std=self.patcher_init_base_std or (self.patcher_dim ** (-0.5)), + a=-3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), + b=3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), ) - - @staticmethod def entropy(scores): """ @@ -2003,8 +1902,6 @@ def entropy(scores): entropy = -p_log_p.sum(dim=-1) return entropy - - @staticmethod def patch_start_ids_from_patch_start_mask(patch_start_mask): bs, trunc_seq_len = patch_start_mask.shape @@ -2017,11 +1914,7 @@ def patch_start_ids_from_patch_start_mask(patch_start_mask): device=patch_start_mask.device, ) else: - patch_ids = ( - torch.arange(trunc_seq_len, device=patch_start_mask.device) - .unsqueeze(0) - .repeat(bs, 1) - ) + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) extra_patch_ids = torch.full( (bs, trunc_seq_len), trunc_seq_len, @@ -2029,12 +1922,8 @@ def patch_start_ids_from_patch_start_mask(patch_start_mask): device=patch_start_mask.device, ) all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) - patch_start_mask_padded = torch.cat( - (patch_start_mask, ~patch_start_mask), dim=1 - ) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( - bs, trunc_seq_len - )[:, :max_patches] + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] return patch_start_ids @staticmethod @@ -2074,14 +1963,8 @@ def find_entropy_patch_start_ids( """ bs, seq_len = entropies.shape[:2] - first_ids = ( - torch.tensor([0, 1], dtype=torch.long, device=entropies.device) - .unsqueeze(0) - .repeat(bs, 1) - ) - preds_truncation_len = first_ids.shape[ - 1 - ] # remove the first preds because they will be start of patches. + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) + preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. entropies = entropies[:, 1:] if threshold is None: num_patches = seq_len // patch_size @@ -2094,9 +1977,7 @@ def find_entropy_patch_start_ids( # patch_start_mask[1:] |= tokens[:-1] < OFFSET patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - patch_start_ids = torch.cat( - (first_ids, patch_start_ids + preds_truncation_len), dim=1 - ) + patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) return patch_start_ids @staticmethod @@ -2112,8 +1993,8 @@ def split_large_numbers(lst, m): new_lst.append(i) assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" return new_lst - - + + def init_hash_embeddings( config, local_encoder_dim: int, @@ -2126,7 +2007,7 @@ def init_hash_embeddings( embeddings = [] emb_dim = local_encoder_dim encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - + for _ in range(config.encoder_hash_byte_group_nb_functions): for _ in encoder_hash_byte_group_size: embeddings.append( @@ -2143,7 +2024,7 @@ def init_hash_embeddings( "BLTPreTrainedModel", "BLTModel", "BLTPatcher", - "LocalEncoder", + "LocalEncoder", "LocalDecoder", "GlobalTransformer", ] diff --git a/src/transformers/models/blt_wip/modeling_blt_wip_backup.py b/src/transformers/models/blt_wip/modeling_blt_wip_backup.py new file mode 100644 index 000000000000..adc4104dcbeb --- /dev/null +++ b/src/transformers/models/blt_wip/modeling_blt_wip_backup.py @@ -0,0 +1,2166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + +import torch +from pydantic import model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention +import json +import logging + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +import os +from contextlib import nullcontext + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + +RMSNorm = nn.RMSNorm + +logger = logging.getLogger() + +from .configuration_blt import ( + BLTConfig, + PatchingModeEnum, + InitStdFactor, +) + +from ...modeling_utils import PreTrainedModel +from ...utils import logging as transformers_logging + +flex_attention_comp = flex_attention + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, + end=self.max_seqlen, + theta=self.theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class BLTAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) / factor + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +class BLTMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) / factor + out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.w1.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + nn.init.trunc_normal_( + self.w3.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + + +class BLTTransformerLayer(nn.Module): + def __init__(self, args): + super().__init__() + + # Extract parameters from dictionary + dim = args['dim'] + n_heads = args['n_heads'] + head_dim = args['head_dim'] + n_kv_heads = args['n_kv_heads'] + rope_theta = args['rope_theta'] + multiple_of = args['multiple_of'] + ffn_dim_multiplier = args['ffn_dim_multiplier'] + norm_eps = args['norm_eps'] + + assert (head_dim is not None) or ( + n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = head_dim or dim // n_heads + self.n_heads = n_heads or dim // head_dim + self.n_kv_heads = n_kv_heads or self.n_heads + + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 + + self.attention = BLTAttention( + dim=dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=rope_theta, + ) + self.feed_forward = BLTMLP( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + block_mask = None + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. + + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +class LocalModelBase(nn.Module): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): + super().__init__() + + # Store config for later use + self.config = config + + # Use component-specific dimensions + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") + + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size + + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_token_id + + self.boe_id = BOE_ID + + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None + + # Create parameter dict for BLTTransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': self.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } + + self.layers = nn.ModuleList( + [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)] + ) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length + else: + self.rope = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb + elif component_type == "decoder": + self.dim_token_emb = config.decoder_dim_token_emb + self.dim_patch_emb = config.dim_global + + self.token_embedding_projection = ( + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + def _should_create_patch_projection(self, config: BLTConfig): + dimension_mismatch = ( + self.dim_patch_emb is not None and self.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + config.cross_attn_encoder and config.cross_attn_init_by_pooling + ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): + return None + + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=self.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + if hasattr(self, "norm"): + self.norm.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + if hasattr(self, "tok_embeddings"): + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + patch_emb_std = self.dim_patch_emb ** (-0.5) + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=patch_emb_std, + a=-3 * patch_emb_std, + b=3 * patch_emb_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(None, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") + + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." + return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds = patch_embeds + patch_embeds_cross + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + + + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +class LocalDecoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="decoder") + + # Model configuration flags + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.norm = RMSNorm(self.dim, eps=config.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + config.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class BLTCrossAttention(nn.Module): + """ + BLTCrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. + """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + #output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype).to(xq.device) + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std or (self.dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + # Store config for later use + self.config = config + + self.dim = config.dim + self.init_base_std = config.init_base_std + self.attn_impl = config.attn_impl + self.attn_bias_type = config.attn_bias_type + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or config.dim // config.n_heads, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + + # Create parameter dict for BLTTransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': config.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(config.n_layers): + self.layers.append(BLTTransformerLayer(layer_params)) + + # GlobalTransformer specific attributes + self.dropout = config.dropout + self.dim_token_emb = config.dim_token_emb + + self.token_embedding_projection = None + if config.dim_token_emb is not None and config.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + config.dim_token_emb, + config.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + + return h, cache + + def init_weights(self): + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + # GlobalTransformer specific initialization + std = self.dim_token_emb ** (-0.5) + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. + + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class BLTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + BLT models. + + This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. + It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. + + Args: + config ([`BLTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init""" + # Don't do anything here - we use the custom init_weights method instead + pass + + +class BLTModel(BLTPreTrainedModel): + """ + The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. + """ + + def __init__(self, config: BLTConfig): + super().__init__(config) + + # Store config reference + self.config = config + + # Create main components - they will read their parameters from config + self.local_encoder = LocalEncoder(config) + + # Create global-specific config by copying config and overriding dimensions + global_config = type(config)(**config.to_dict()) + global_config.dim = config.dim_global + global_config.n_layers = config.n_layers_global + global_config.n_heads = config.n_heads_global + global_config.n_kv_heads = config.n_kv_heads_global + global_config.dim_token_emb = config.global_dim_patch_emb + + self.global_transformer = GlobalTransformer(global_config) + self.local_decoder = LocalDecoder(config) + + # Initialize hash embeddings + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + # Initialize patcher if needed + if config.patch_in_forward: + if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: + # Load entropy model directly + entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir + + if not os.path.exists(entropy_model_checkpoint_dir): + raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}") + + # Load entropy model parameters + params_path = os.path.join(entropy_model_checkpoint_dir, "params.json") + if not os.path.exists(params_path): + raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}") + + with open(params_path) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["entropy_model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) + + # Override patcher configuration with actual entropy model parameters from checkpoint + config.patcher_dim = model_params["dim"] + config.patcher_n_layers = model_params["n_layers"] + config.patcher_n_heads = model_params["n_heads"] + config.patcher_max_seqlen = model_params["max_seqlen"] + config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"] + config.patcher_vocab_size = model_params["vocab_size"] + # Use sensible defaults for parameters not in checkpoint + config.patcher_attn_bias_type = "local_block_causal" + config.patcher_attn_impl = "sdpa" # originally xformers + config.patcher_sliding_window = 512 + + # BLTPatcher will extract patcher_ parameters from config directly + self.patcher = BLTPatcher(config) + + state_path = os.path.join( + entropy_model_checkpoint_dir, "consolidated.pth" + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.patcher.load_state_dict( + torch.load(state_path, map_location=device)["model"], strict=False + ) + self.patcher.to(device) + self.patcher = self.patcher.eval() + # no grads for the model: + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + # Initialize weights and apply final processing + self.post_init() + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total + ], dim=-1) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + + def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: + """ + Create decoder patch IDs by skipping the first encoder patch. + + The decoder starts after the first patch (which contains BOE tokens), + so we need to map decoder positions to the remaining patches. + + Args: + patch_lengths: [batch_size, num_patches] from encoder + nb_boe: number of beginning-of-example tokens in first patch + seq_len: decoder sequence length + + Returns: + decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices + """ + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) + decoder_patch_lengths = patch_lengths[:, 1:] + + # Create patch IDs for the decoder sequence using the remaining patches + return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) + + + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ): + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.config.patch_size, + boe_id=BOE_ID, + ) + + # Patching + if patch_lengths is None: + # assert ( + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." + + # PATCHER MODEL DEFINED + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.config.patch_size, + include_next_token=True, + threshold=self.config.patching_threshold, + threshold_add=self.config.patching_threshold_add, + monotonicity=self.config.monotonicity, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + # Apply any processing to patch lengths + if self.config.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = self._patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.config.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + ) + + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = self._decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.config.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def init_weights(self): + self.local_encoder.init_weights() + self.global_transformer.init_weights() + self.local_decoder.init_weights() + + if self.encoder_hash_tok_embedding is not None: + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, + ) + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + # Store config reference for later use + self.config = config + + # Extract patcher parameters from BLTConfig + self.dim = config.patcher_dim + self.init_base_std = config.patcher_init_base_std + self.attn_impl = config.patcher_attn_impl + self.attn_bias_type = config.patcher_attn_bias_type + self.init_std_factor = config.patcher_init_std_factor + self.max_seqlen = config.patcher_max_seqlen + n_layers = config.patcher_n_layers + n_heads = config.patcher_n_heads + head_dim = config.patcher_head_dim + rope_theta = config.patcher_rope_theta + rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product + norm_eps = config.patcher_norm_eps + vocab_size = config.patcher_vocab_size + weight_tying = config.patcher_weight_tying + sliding_window = config.patcher_sliding_window + eos_token_id = config.patcher_eos_token_id + + self.rope_embeddings = RotaryEmbedding( + theta=rope_theta, + head_dim=head_dim or self.dim // n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = eos_token_id + + # Extract additional parameters for BLTTransformerLayer + n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) + multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) + ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) + + # Create a simple parameter dict for BLTTransformerLayer + layer_params = { + 'dim': self.dim, + 'n_heads': n_heads, + 'head_dim': head_dim, + 'n_kv_heads': n_kv_heads, + 'rope_theta': rope_theta, + 'multiple_of': multiple_of, + 'ffn_dim_multiplier': ffn_dim_multiplier, + 'norm_eps': norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(n_layers): + self.layers.append(BLTTransformerLayer(layer_params)) + + # LMTransformer specific attributes + self.weight_tying = weight_tying + self.sliding_window = sliding_window + + assert vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) + + self.norm = RMSNorm(self.dim, eps=norm_eps) + + self.output = nn.Linear( + self.dim, + vocab_size, + bias=False, + ) + + if self.weight_tying: + self.output.weight = self.tok_embeddings.weight + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + attn_impl: str | None = None, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + threshold_add: Optional[float] = None, + monotonicity: bool = False, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1 + device: Optional[str] = None, + enable_grad: bool = False, + ): + attn_impl = self.attn_impl if attn_impl is None else attn_impl + + # Handle chunked processing for entropy calculation + # grad_context = nullcontext() if enable_grad else torch.no_grad() + # with grad_context: + entropies = [] + preds = [] + max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=split, + eos_id=self.eos_id, + ) + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) + + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0) + concat_preds = concat_preds.reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold, + threshold_add=threshold_add, + monotonicity=monotonicity, + ) + patch_lengths = self.patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + # Default to byte-level patching + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device + ) + + # Apply any processing to patch lengths + if max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + self.split_large_numbers(pl, max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=token_values.dtype, device=token_values.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + + return concat_entropies, patch_lengths, concat_preds + + def reset_parameters(self, init_std=None): + self.norm.reset_parameters() + + def init_weights(self): + self.reset_parameters() + init_std = self.dim ** (-0.5) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + @staticmethod + def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. + entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + + @staticmethod + def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "LocalEncoder", + "LocalDecoder", + "GlobalTransformer", +] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py index f827302aaa4e..ff31d655ae34 100644 --- a/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py +++ b/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py @@ -12,9 +12,7 @@ def decode(self, tokens: list[int]): pass @abc.abstractmethod - def get_token_offsets( - self, text: str, tokens: list[int] | None = None - ) -> tuple[list[str], list[int]]: + def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]: """Return the offsets of the tokens in the original text. Only used for evaluation.""" pass diff --git a/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py index 6d874d910c11..2d018ff90ead 100644 --- a/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py +++ b/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py @@ -50,11 +50,7 @@ def text2bytes_bpe_delims( # Remove the '▁' characters bpe_strs = [] for i, bpe_str in enumerate(cur_bpe_strs): - if ( - len(bpe_strs) <= 1 - and all([c == " " for s in bpe_strs for c in s]) - and not all(c == "▁" for c in bpe_str) - ): + if len(bpe_strs) <= 1 and all([c == " " for s in bpe_strs for c in s]) and not all(c == "▁" for c in bpe_str): # Remove leading space for first non space token. bpe_str = bpe_str.replace("▁", "") elif i == 0 and all(c == "▁" for c in bpe_str): @@ -93,9 +89,7 @@ def __init__( self.bpe_id = BPE_ID self.bpe_tokenizer_path = bpe_tokenizer_path if bpe_delim: - self.bpe_tokenizer = SentencePieceTokenizer( - model_path=self.bpe_tokenizer_path - ) + self.bpe_tokenizer = SentencePieceTokenizer(model_path=self.bpe_tokenizer_path) else: self.bpe_tokenizer = None self.bpe_delim = bpe_delim @@ -106,9 +100,7 @@ def __init__( def get_vocab_size(self) -> int: return self.n_words - def encode( - self, text: str, add_bos: bool | None = None, add_eos: bool | None = None - ): + def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): if add_bos is None: add_bos = self.add_bos if add_eos is None: @@ -143,11 +135,7 @@ def decode(self, tokens: list[int], cut_at_eos: bool = False): tokens = tokens[: k + 1] break return bytes( - [ - tok - self.offsetting_special_char - for tok in tokens - if tok - self.offsetting_special_char >= 0 - ] + [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] ).decode("utf-8", errors="ignore") def get_token_offsets(self, text: str, tokens: list[int] | None = None): diff --git a/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py index f789ae77d7fa..fece8cbce7a8 100644 --- a/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py +++ b/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py @@ -2,6 +2,7 @@ import logging import os + try: from sentencepiece import SentencePieceProcessor @@ -11,13 +12,12 @@ from .abstract_tokenizer import Tokenizer + logger = logging.getLogger(__name__) class SentencePieceTokenizer(Tokenizer): - def __init__( - self, model_path: str, add_bos: bool = True, add_eos: bool = True - ) -> None: + def __init__(self, model_path: str, add_bos: bool = True, add_eos: bool = True) -> None: assert os.path.isfile(model_path), model_path self.sp_model = SentencePieceProcessor(model_file=model_path) @@ -30,9 +30,7 @@ def __init__( self.pad_id: int = self.sp_model.pad_id() self.add_bos = add_bos self.add_eos = add_eos - logger.info( - f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" - ) + logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() def get_vocab_size(self) -> int: @@ -45,17 +43,13 @@ def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = Non if add_eos is None: add_eos = self.add_eos assert type(s) is str - tokens = ( - [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos - ) + tokens = [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos return tokens def decode(self, tokens: list[int]): return self.sp_model.decode(tokens) - def get_token_offsets( - self, text: str, tokens: list[int] | None = None - ) -> tuple[list[str], list[int]]: + def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]: pieces = self.sp_model.encode_as_immutable_proto(text).pieces substrs = [p.surface for p in pieces] offsets = [p.begin for p in pieces] diff --git a/src/transformers/models/blt_wip/unified_blt_debug/config.json b/src/transformers/models/blt_wip/unified_blt_debug/config.json new file mode 100644 index 000000000000..67ab54a27541 --- /dev/null +++ b/src/transformers/models/blt_wip/unified_blt_debug/config.json @@ -0,0 +1,144 @@ +{ + "args": { + "alpha_depth": "disabled", + "architecture": "vanilla", + "attn_bias_type": "block_causal", + "attn_impl": "xformers", + "attn_to_keep": "all", + "conv_kernel_size": null, + "cross_attn_all_layers_decoder": true, + "cross_attn_all_layers_encoder": false, + "cross_attn_decoder": true, + "cross_attn_encoder": true, + "cross_attn_init_by_pooling": true, + "cross_attn_k": 2, + "cross_attn_nheads": 16, + "cross_attn_use_flex_attention": true, + "cross_attn_window_decoder": null, + "cross_attn_window_encoder": null, + "custom_bwd": false, + "dim": 512, + "dim_global": 2048, + "dim_local_decoder": 1024, + "dim_local_encoder": 1024, + "dim_patch_emb": null, + "dim_token": null, + "dim_token_emb": null, + "downsampling_by_pooling": "max", + "dropout": 0.0, + "encoder_enable_byte_group_hash": false, + "encoder_enable_byte_ngrams": false, + "encoder_hash_byte_group_nb_functions": 1, + "encoder_hash_byte_group_size": [ + 3, + 4, + 5, + 6, + 7, + 8 + ], + "encoder_hash_byte_group_vocab": 500002, + "encoder_lm_loss": false, + "encoder_ngram_table_dir": null, + "encoder_ngram_to_size_str": null, + "encoder_preds_low_entropy_toks": null, + "encoder_preds_random_toks": null, + "entropy_model_checkpoint_dir": null, + "entropy_model_is_ngram_model": false, + "eos_id": 2, + "ffn_dim_multiplier": 1.0, + "full_logging_n_layers": 4, + "fuse_sequence_parallel": false, + "global_local_decoder_residual_layer": null, + "head_dim": null, + "init_base_std": null, + "init_std_factor": "current_depth", + "init_use_depth": "current", + "init_use_gaussian": true, + "layer_ckpt": "none", + "local_attention_window_len": 512, + "log_patch_lengths": false, + "loss_parallel": false, + "max_encoder_seq_length": 24576, + "max_length": 256, + "max_patch_length": null, + "max_seqlen": 4096, + "monotonicity": false, + "multiple_of": 256, + "n_heads": 8, + "n_heads_global": 16, + "n_heads_local_decoder": 16, + "n_heads_local_encoder": 16, + "n_kv_heads": null, + "n_kv_heads_global": null, + "n_layers": 8, + "n_layers_global": 25, + "n_layers_local_decoder": 9, + "n_layers_local_encoder": 1, + "ngram_vocab_sizes": null, + "non_linearity": "swiglu", + "norm_affine": true, + "norm_eps": 1e-05, + "norm_type": "rmsnorm", + "output_size": -1, + "pad_to_max_length": true, + "patch_in_forward": true, + "patch_size": 4.5, + "patching_batch_size": 1, + "patching_device": "cuda", + "patching_mode": "entropy", + "patching_threshold": 1.335442066192627, + "patching_threshold_add": null, + "patching_thresholds_str": null, + "pm_size": 0, + "pre_norm": true, + "recompute_attn": false, + "recompute_fc1_out": false, + "recompute_fc3_out": false, + "rope_theta": 500000.0, + "rope_use_fp32_in_outer_product": true, + "seed": 42, + "sequence_parallel": false, + "share_encoder_decoder_emb": true, + "tie_local_encoder_decoder": false, + "tie_local_encoder_decoder_logits": false, + "tokenize_with_bpe_delimiter": false, + "use_fsdp": true, + "use_local_encoder_transformer": true, + "use_rope": true, + "vocab_size": 260, + "weight_tying": false + }, + "patch_in_forward": true, + "realtime_patching": true, + "patching_mode": "entropy", + "patch_size": 4.5, + "patching_threshold": 1.335442066192627, + "patching_threshold_add": null, + "max_patch_length": null, + "patching_batch_size": 1, + "patching_device": "cuda", + "monotonicity": false, + "patcher_vocab_size": 260, + "patcher_dim": 768, + "patcher_n_layers": 14, + "patcher_n_heads": 12, + "patcher_head_dim": null, + "patcher_n_kv_heads": null, + "patcher_max_seqlen": 8192, + "patcher_norm_eps": 1e-05, + "patcher_dropout": 0.0, + "patcher_sliding_window": 512, + "patcher_ffn_dim_multiplier": 1.0, + "patcher_multiple_of": 256, + "patcher_rope_theta": 10000.0, + "patcher_rope_use_fp32_in_outer_product": false, + "patcher_attn_impl": "xformers", + "patcher_attn_bias_type": "local_block_causal", + "patcher_init_base_std": null, + "patcher_init_std_factor": "current_depth", + "patcher_dim_token_emb": null, + "patcher_weight_tying": false, + "patcher_bos_token_id": 1, + "patcher_eos_token_id": 2 +} \ No newline at end of file From 12c000e876d659ed58a0aa797840316531278c3b Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 16 Jun 2025 07:46:09 +0000 Subject: [PATCH 014/139] clean a few comments --- .../models/blt_wip/modeling_blt_wip.py | 314 ++++-------------- 1 file changed, 60 insertions(+), 254 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt_wip.py index 3cd5edccbc38..fe846fc0ab9f 100644 --- a/src/transformers/models/blt_wip/modeling_blt_wip.py +++ b/src/transformers/models/blt_wip/modeling_blt_wip.py @@ -16,6 +16,7 @@ PatchingModeEnum, ) + SEP = " " BOS_ID: int = 1 EOS_ID: int = 2 @@ -190,13 +191,6 @@ def __init__( persistent=False, ) - def reset_parameters(self): - self.freqs_cis[...] = precompute_freqs_cis( - dim=self.head_dim, - end=self.max_seqlen, - theta=self.theta, - rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ) def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): """ @@ -315,25 +309,7 @@ def forward( return output - def reset_parameters(self, init_std=None, factor=1.0): - init_std = init_std or (self.dim ** (-0.5)) / factor - - for w in [self.wq, self.wk, self.wv]: - nn.init.trunc_normal_( - w.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - nn.init.trunc_normal_( - self.wo.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) class BLTMLP(nn.Module): @@ -379,31 +355,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.w2(F.silu(x1) * x3) return output - def reset_parameters(self, init_std=None, factor=1.0): - in_init_std = init_std or (self.dim ** (-0.5)) / factor - out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor - nn.init.trunc_normal_( - self.w1.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - nn.init.trunc_normal_( - self.w2.weight, - mean=0.0, - std=out_init_std, - a=-3 * out_init_std, - b=3 * out_init_std, - ) - nn.init.trunc_normal_( - self.w3.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) class BLTTransformerLayer(nn.Module): @@ -465,12 +417,7 @@ def forward( out = h + self.feed_forward(h_norm) return out - def init_weights(self, init_std=None, factor=1.0): - self.attention.reset_parameters(init_std, factor) - self.attention_norm.reset_parameters() - self.feed_forward.reset_parameters(init_std, factor) - self.ffn_norm.reset_parameters() def rightpad(seq, pad_id, max_len): @@ -832,65 +779,7 @@ def apply_embedding(self, tokens, embeds): else: return self.tok_embeddings(tokens) - def init_weights(self, init_std=None): - self.rope.reset_parameters() - if hasattr(self, "norm"): - self.norm.reset_parameters() - - init_std = init_std or (self.dim ** (-0.5)) - if hasattr(self, "tok_embeddings"): - nn.init.trunc_normal_( - self.tok_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - if self.pos_embeddings is not None: - nn.init.trunc_normal_( - self.pos_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - - if hasattr(self, "output"): - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - if self.token_embedding_projection is not None: - nn.init.trunc_normal_( - self.token_embedding_projection.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - if self.patch_embedding_projection is not None: - patch_emb_std = self.dim_patch_emb ** (-0.5) - nn.init.trunc_normal_( - self.patch_embedding_projection.weight, - mean=0.0, - std=patch_emb_std, - a=-3 * patch_emb_std, - b=3 * patch_emb_std, - ) - if self.cross_attn_layers is not None: - for depth, layer in enumerate(self.cross_attn_layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(None, factor) class LocalEncoder(LocalModelBase): @@ -1185,42 +1074,7 @@ def forward( return x + output - def init_weights(self, base_std: float, factor: float = 1.0): - std = base_std or (self.dim ** (-0.5)) / factor - - nn.init.trunc_normal_( - self.wq.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - nn.init.trunc_normal_( - self.wk.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - nn.init.trunc_normal_( - self.wv.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - nn.init.trunc_normal_( - self.wo.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - self.cross_attn_norm_q.reset_parameters() - self.cross_attn_norm_kv.reset_parameters() class GlobalTransformer(nn.Module): @@ -1309,22 +1163,7 @@ def forward( return h, cache - def init_weights(self): - self.rope_embeddings.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - # GlobalTransformer specific initialization - std = self.dim_token_emb ** (-0.5) - if self.token_embedding_projection is not None: - nn.init.trunc_normal_( - self.token_embedding_projection.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) def compute_hash_embeddings( @@ -1372,19 +1211,6 @@ def compute_hash_embeddings( class BLTPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - BLT models. - - This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. - It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. - - Args: - config ([`BLTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. - """ - config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -1395,19 +1221,67 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_cache_class = False def _init_weights(self, module): - """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init""" - # Don't do anything here - we use the custom init_weights method instead - pass + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): + nn.init.ones_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, RotaryEmbedding): + module.freqs_cis[...] = precompute_freqs_cis( + dim=module.head_dim, + end=module.max_seqlen, + theta=module.theta, + rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.local_encoder.dim ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, (LocalEncoder, LocalDecoder)): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) + + elif isinstance(module, GlobalTransformer): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.patcher_dim ** (-0.5) + module.tok_embeddings._custom_std = emb_std + module.output._custom_std = emb_std class BLTModel(BLTPreTrainedModel): - """ - The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences - by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, - and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for - improved performance and inference efficiency. - """ - def __init__(self, config: BLTConfig): super().__init__(config) @@ -1430,24 +1304,7 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - def init_weights(self): - self.local_encoder.init_weights() - self.global_transformer.init_weights() - self.local_decoder.init_weights() - - if self.encoder_hash_tok_embedding is not None: - emb_std = self.local_encoder.dim ** (-0.5) - for emb in self.encoder_hash_tok_embedding: - nn.init.trunc_normal_( - emb.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) - if self.patcher is not None: - self.patcher.init_weights() def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """ @@ -1723,7 +1580,6 @@ def __init__(self, config): ) # LMTransformer specific attributes - self.weight_tying = config.patcher_weight_tying self.sliding_window = config.patcher_sliding_window assert config.patcher_vocab_size > 0 @@ -1833,59 +1689,9 @@ def forward( return concat_entropies, patch_lengths, concat_preds - def init_weights(self): - """Initialize weights for the patcher model""" - # Initialize RoPE embeddings - self.rope_embeddings.reset_parameters() - - # Initialize norm layer - self.norm.reset_parameters() - - # Initialize token embeddings - emb_std = self.patcher_dim ** (-0.5) - nn.init.trunc_normal_( - self.tok_embeddings.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) - # Initialize transformer layers - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.patcher_init_base_std, factor) - # Initialize output layer if not weight tied - if not self.weight_tying: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) - def _init_weights(self, module): - """Initialize weights for a specific module""" - if isinstance(module, nn.Linear): - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=self.patcher_init_base_std or (self.patcher_dim ** (-0.5)), - a=-3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), - b=3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), - ) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=self.patcher_init_base_std or (self.patcher_dim ** (-0.5)), - a=-3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), - b=3 * (self.patcher_init_base_std or (self.patcher_dim ** (-0.5))), - ) @staticmethod def entropy(scores): @@ -2027,4 +1833,4 @@ def init_hash_embeddings( "LocalEncoder", "LocalDecoder", "GlobalTransformer", -] +] \ No newline at end of file From 4b2185db80e7a540d65419c30f4300194ced7aa5 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 16 Jun 2025 08:32:58 +0000 Subject: [PATCH 015/139] cleanup folder --- src/demo_hf.py | 6 +- src/transformers/models/blt_wip/blt_args.py | 187 -- .../{modeling_blt_wip.py => modeling_blt.py} | 0 .../models/blt_wip/modeling_blt_wip_backup.py | 2166 ----------------- .../models/blt_wip/tokenization_blt.py | 273 +++ .../models/blt_wip/tokenizers/__init__.py | 1 - .../blt_wip/tokenizers/abstract_tokenizer.py | 21 - .../blt_wip/tokenizers/blt_tokenizer.py | 143 -- .../tokenizers/sentence_piece_tokenizer.py | 56 - .../blt_wip/unified_blt_debug/config.json | 144 -- 10 files changed, 276 insertions(+), 2721 deletions(-) delete mode 100644 src/transformers/models/blt_wip/blt_args.py rename src/transformers/models/blt_wip/{modeling_blt_wip.py => modeling_blt.py} (100%) delete mode 100644 src/transformers/models/blt_wip/modeling_blt_wip_backup.py create mode 100644 src/transformers/models/blt_wip/tokenization_blt.py delete mode 100644 src/transformers/models/blt_wip/tokenizers/__init__.py delete mode 100644 src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py delete mode 100644 src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py delete mode 100644 src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py delete mode 100644 src/transformers/models/blt_wip/unified_blt_debug/config.json diff --git a/src/demo_hf.py b/src/demo_hf.py index d88be6783480..a1fa640c5c42 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -4,7 +4,7 @@ import torch from transformers.models.blt_wip.modeling_blt_wip import BLTModel -from transformers.models.blt_wip.tokenizers.blt_tokenizer import BltTokenizer +from transformers.models.blt_wip.tokenization_blt import BLTTokenizer logger = logging.getLogger() @@ -43,7 +43,7 @@ def generate( prompts: list[str] | None, *, model: BLTModel, - tokenizer: BltTokenizer, + tokenizer: BLTTokenizer, max_prompt_len: int = 256, max_gen_len: int = 256, use_sampling: bool = False, @@ -101,7 +101,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): blt_repo = "itazap/blt-1b" model = BLTModel.from_pretrained(blt_repo).to(device) - tokenizer = BltTokenizer(vocab_size_unit_1=model.config.vocab_size, add_bos=True, add_eos=True) + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) prompts = [prompt] diff --git a/src/transformers/models/blt_wip/blt_args.py b/src/transformers/models/blt_wip/blt_args.py deleted file mode 100644 index e043d1dd20a8..000000000000 --- a/src/transformers/models/blt_wip/blt_args.py +++ /dev/null @@ -1,187 +0,0 @@ -from enum import Enum -from typing import Any - -from pydantic import BaseModel, ConfigDict, model_validator -from typing_extensions import Self - - -EOS_ID: int = 2 - - -class InitStdFactor(str, Enum): - DISABLED = "disabled" # Init std is divided by 1.0 - GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) - CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) - DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 - - -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - -class LMTransformerArgs(BaseModel): - """Arguments for the Language Model Transformer (used as entropy model for patching)""" - - model_config = ConfigDict() - - # Basic architecture - dim: int = 512 - n_layers: int = 8 - head_dim: int | None = None - n_heads: int | None = None - n_kv_heads: int | None = None - - # Transformer configuration - max_seqlen: int = 1024 - norm_eps: float = 1e-5 - dropout: float = 0 - vocab_size: int = -1 - sliding_window: int | None = None - - # Feedforward - ffn_dim_multiplier: float | None = None - multiple_of: int = 256 - - # Positional encoding - rope_theta: float = 10000.0 - rope_use_fp32_in_outer_product: bool = False - - # Attention - attn_impl: str = "sdpa" - attn_bias_type: str = "causal" - - # Initialization - init_base_std: float | None = None - init_std_factor: InitStdFactor = InitStdFactor.DISABLED - - # Embedding dimensions - dim_token_emb: int | None = None - - # Model behavior - weight_tying: bool = False - seed: int = 42 - - # Special token config - eos_id: int = EOS_ID - - -class ByteLatentTransformerArgs(BaseModel): - """Arguments for the Byte Latent Transformer (main BLT model)""" - - model_config = ConfigDict() - - # Basic model configuration - seed: int = 42 - vocab_size: int = -1 - - # Main architecture dimensions (these will be used for creating transformer args) - dim: int = 512 - n_layers: int = 8 - head_dim: int | None = None - n_heads: int | None = None - n_kv_heads: int | None = None - - # Component-specific dimensions - dim_global: int = 512 - dim_local_decoder: int = 512 - dim_local_encoder: int = 512 - n_layers_global: int = 8 - n_layers_local_decoder: int = 8 - n_layers_local_encoder: int = 8 - n_heads_global: int = 8 - n_heads_local_decoder: int = 8 - n_heads_local_encoder: int = 8 - n_kv_heads_global: int | None = None - - # Transformer configuration (needed by transformer components) - max_seqlen: int = 1024 - norm_eps: float = 1e-5 - dropout: float = 0 - - # Feedforward (needed by transformer components) - ffn_dim_multiplier: float = 1.0 - multiple_of: int = 256 - - # Positional encoding (needed by transformer components) - rope_theta: float = 10000.0 - rope_use_fp32_in_outer_product: bool = False - - # Attention (needed by transformer components) - attn_impl: str = "sdpa" - attn_bias_type: str = "causal" - - # Initialization (needed by transformer components) - init_base_std: float | None = None - init_std_factor: InitStdFactor = InitStdFactor.DISABLED - - # Embedding dimensions (needed by transformer components) - dim_token_emb: int | None = None - - # Patching configuration - patch_in_forward: bool = False - realtime_patching: bool = True - patch_size: float | None = None - patching_mode: str | None = None - patching_threshold: float | None = None - patching_threshold_add: float | None = None - monotonicity: bool = False - patching_batch_size: int = 1 - patching_device: str = "cuda" - max_patch_length: int | None = None - entropy_model_checkpoint_dir: str | None = None - - # Cross attention configurations - cross_attn_encoder: bool = False - cross_attn_decoder: bool = False - cross_attn_window_encoder: int | None = None - cross_attn_window_decoder: int | None = None - cross_attn_k: int | None = None - cross_attn_nheads: int | None = None - cross_attn_all_layers_decoder: bool = False - cross_attn_all_layers_encoder: bool = False - cross_attn_use_flex_attention: bool = True - cross_attn_init_by_pooling: bool = False - - # Encoder configurations - use_local_encoder_transformer: bool = False - max_encoder_seq_length: int | None = None - encoder_hash_byte_group_size: Any | None = None - encoder_hash_byte_group_vocab: int = 30000 - encoder_hash_byte_group_nb_functions: int = 3 - encoder_enable_byte_ngrams: bool = False - encoder_ngram_to_size_str: str | None = None - downsampling_by_pooling: str | None = None - - # Architecture and dimensions - dim_token: int | None = None - share_encoder_decoder_emb: bool = True - weight_tying: bool = False - - # Attention configuration - local_attention_window_len: int | None = None - use_rope: bool = True - - # Performance optimization - sequence_parallel: bool = False - loss_parallel: bool = False - fuse_sequence_parallel: bool = False - use_fsdp: bool = True - - # Parameter mixing - pm_size: int = 0 - - # Special token config - eos_id: int = EOS_ID - - @model_validator(mode="after") - def check_hash_byte_sizes(self) -> Self: - if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: - self.encoder_hash_byte_group_size = [ - int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 - ] - return self diff --git a/src/transformers/models/blt_wip/modeling_blt_wip.py b/src/transformers/models/blt_wip/modeling_blt.py similarity index 100% rename from src/transformers/models/blt_wip/modeling_blt_wip.py rename to src/transformers/models/blt_wip/modeling_blt.py diff --git a/src/transformers/models/blt_wip/modeling_blt_wip_backup.py b/src/transformers/models/blt_wip/modeling_blt_wip_backup.py deleted file mode 100644 index adc4104dcbeb..000000000000 --- a/src/transformers/models/blt_wip/modeling_blt_wip_backup.py +++ /dev/null @@ -1,2166 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -from enum import Enum -from typing import Any, List, Optional, Tuple, Union - -import torch -from pydantic import model_validator -from torch import nn -from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention -import json -import logging - -import torch -import torch.nn -import torch.nn as nn -from torch.nn import functional as F - -import os -from contextlib import nullcontext - -SEP = " " -BOS_ID: int = 1 -EOS_ID: int = 2 -PAD_ID: int = -1 -BOE_ID: int = 0 -BPE_ID: int = 3 -OFFSET: int = 4 - -BYTE_UNITS: int = 256 - -RMSNorm = nn.RMSNorm - -logger = logging.getLogger() - -from .configuration_blt import ( - BLTConfig, - PatchingModeEnum, - InitStdFactor, -) - -from ...modeling_utils import PreTrainedModel -from ...utils import logging as transformers_logging - -flex_attention_comp = flex_attention - - -def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - -def create_causal_mask( - seqlen, - attn_impl: str, - attn_bias_type: str | None, - *, - eos_id: int | None = None, - tokens: torch.Tensor | None = None, - sliding_window: int | None = None, -): - if attn_impl == "sdpa": - BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) - - if attn_bias_type == "causal": - return "causal" - - if BLT_SUPPRESS_ATTN_ERROR == 1: - return "causal" - else: - raise ValueError( - "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" - ) - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) - -def cross_entropy(pred, target, **kwargs): - return F.nll_loss( - F.log_softmax(pred.flatten(end_dim=-2).float(), -1), - target.flatten(end_dim=-1), - **kwargs, - ) - - -def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -def precompute_freqs_cis( - dim: int, - end: int, - theta: float = 10000.0, - rope_use_fp32_in_outer_product: bool = False, -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - if rope_use_fp32_in_outer_product: - t = t.to(torch.float32) - - freqs = torch.outer(t, freqs).float() - - cos, sin = freqs.cos(), freqs.sin() - - return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - seq_dim (int): Sequence dimension index. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert 0 <= seq_dim < ndim - assert freqs_cis.shape == ( - x.shape[seq_dim], - x.shape[-3], - 2, - 2, - ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" - shape = [ - d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) - ] + [2, 2] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - seq_dim: int, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - freqs_cis = reshape_for_broadcast( - freqs_cis, xq_, seq_dim - ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 - xq_out = (xq_ * freqs_cis).sum(5).flatten(3) - xk_out = (xk_ * freqs_cis).sum(5).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. -class RotaryEmbedding(torch.nn.Module): - """ - RotaryEmbedding Module - """ - - def __init__( - self, - theta: float, - head_dim: int, - max_seqlen: int = 1024, - rope_use_fp32_in_outer_product: bool = False, - ): - super().__init__() - - self.theta = theta - self.head_dim = head_dim - self.max_seqlen = max_seqlen - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - dim=head_dim, - end=max_seqlen, - theta=theta, - rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ), - persistent=False, - ) - - def reset_parameters(self): - self.freqs_cis[...] = precompute_freqs_cis( - dim=self.head_dim, - end=self.max_seqlen, - theta=self.theta, - rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ) - - def forward( - self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None - ): - """ - Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions - Args: - seqlen (int): Contiguous sequence length - tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen - - Returns: - Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis - """ - test = (seqlen is not None) or (tok_idx is not None) - assert test, "Should provide atleast seqlen or tok_idx" - if tok_idx is not None: - return self.freqs_cis[tok_idx] - elif seqlen is not None: - return self.freqs_cis[0:seqlen] - - -class BLTAttention(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - rope_theta: float, - ): - super().__init__() - - self.dim = dim - self.head_dim = head_dim - self.rope_theta = rope_theta - - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads - - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, - ) - - def forward( - self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", - ) -> torch.Tensor: - # B S D - bsz, seq_len, dim = x.shape - xq = self.wq(x.view_as(x)) - xk = self.wk(x.view_as(x)) - xv = self.wv(x.view_as(x)) - - output_shape = xq.shape - # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) - - # This condition helps us be easily compatible - # with inference by adding a pluggable KVCache - if hasattr(self, "kv_cache"): - xk, xv = self.kv_cache.update(xk, xv, tok_idx) - - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) - - if attn_impl == "flex_attention": - assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) - output = flex_attention_comp(xq, xk, xv, block_mask=mask) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - - elif attn_impl == "sdpa": - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) - assert mask is None or isinstance(mask, (str, torch.Tensor)) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None - output = F.scaled_dot_product_attention( - xq, - xk, - xv, - is_causal=is_causal, - attn_mask=mask, - ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - else: - raise NotImplementedError( - f"Attention implementation {attn_impl} not supported" - ) - - output_reshaped = output.reshape(output_shape) - - output = self.wo(output_reshaped) - - return output - - def reset_parameters(self, init_std=None, factor=1.0): - init_std = init_std or (self.dim ** (-0.5)) / factor - - for w in [self.wq, self.wk, self.wv]: - nn.init.trunc_normal_( - w.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - nn.init.trunc_normal_( - self.wo.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - -class BLTMLP(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - mp_size: int = 1, - ): - super().__init__() - - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - assert hidden_dim % mp_size == 0 - - self.dim = dim - self.hidden_dim = hidden_dim - - self.w1 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w3 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w2 = nn.Linear( - hidden_dim, - dim, - bias=False, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # B S D - x1 = self.w1(x.view_as(x)) - x3 = self.w3(x.view_as(x)) - output = self.w2(F.silu(x1) * x3) - return output - - def reset_parameters(self, init_std=None, factor=1.0): - in_init_std = init_std or (self.dim ** (-0.5)) / factor - out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor - - nn.init.trunc_normal_( - self.w1.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - nn.init.trunc_normal_( - self.w2.weight, - mean=0.0, - std=out_init_std, - a=-3 * out_init_std, - b=3 * out_init_std, - ) - nn.init.trunc_normal_( - self.w3.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - - -class BLTTransformerLayer(nn.Module): - def __init__(self, args): - super().__init__() - - # Extract parameters from dictionary - dim = args['dim'] - n_heads = args['n_heads'] - head_dim = args['head_dim'] - n_kv_heads = args['n_kv_heads'] - rope_theta = args['rope_theta'] - multiple_of = args['multiple_of'] - ffn_dim_multiplier = args['ffn_dim_multiplier'] - norm_eps = args['norm_eps'] - - assert (head_dim is not None) or ( - n_heads is not None - ), "Should specify at least head_dim or n_heads" - self.head_dim = head_dim or dim // n_heads - self.n_heads = n_heads or dim // head_dim - self.n_kv_heads = n_kv_heads or self.n_heads - - assert n_heads % self.n_kv_heads == 0 - assert dim % n_heads == 0 - - self.attention = BLTAttention( - dim=dim, - head_dim=self.head_dim, - n_heads=self.n_heads, - n_kv_heads=self.n_kv_heads, - rope_theta=rope_theta, - ) - self.feed_forward = BLTMLP( - dim=dim, - hidden_dim=4 * dim, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - ) - self.attention_norm = RMSNorm(dim, eps=norm_eps) - self.ffn_norm = RMSNorm(dim, eps=norm_eps) - - def forward( - self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", - ) -> torch.Tensor: - norm_x = self.attention_norm(x) - attn_out = self.attention( - norm_x, - freq_cis, - tok_idx=tok_idx, - mask=mask, - attn_impl=attn_impl, - ) - h = x + attn_out - h_norm = self.ffn_norm(h) - out = h + self.feed_forward(h_norm) - return out - - def init_weights(self, init_std=None, factor=1.0): - self.attention.reset_parameters(init_std, factor) - self.attention_norm.reset_parameters() - - self.feed_forward.reset_parameters(init_std, factor) - self.ffn_norm.reset_parameters() - - -def rightpad(seq, pad_id, max_len): - return seq + [pad_id] * (max_len - len(seq)) - - -def check_non_zero_after_zero(tensor): - zero_mask = tensor == 0 - shifted_mask = torch.cat( - [ - torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), - zero_mask[:, :-1], - ], - dim=1, - ) - non_zero_after_zero = (tensor != 0) & shifted_mask - return non_zero_after_zero.any() - - -def fill_tokens(tokens, patch_size, fill_id): - batch_size, seq_len = tokens.shape - if seq_len % patch_size == 0: - return tokens - else: - remaining = patch_size - seq_len % patch_size - final_padding = tokens.new(batch_size, remaining).fill_(fill_id) - return torch.cat((tokens, final_padding), dim=1) - - -def rolling_polynomial_hash(t, hash_func_nb: int = 0): - primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) - prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) - return torch.sum(t * prime_powers, dim=-1) - -def byte_group_hash_function( - x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 -): - """ - Returns a hash of the input x and maps it to a value in the range [0, max_hash]. - - expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. - returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. - - Note: max hash can make a big difference on the number of collisions. - """ - with torch.no_grad(): - bs, seq_len = x.shape - prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) - x = torch.cat([prefix, x], dim=1) - windows = x.unfold(1, group_size, 1) - # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) - hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values_range = hashes % max_hash - hash_values_range.requires_grad = False - return hash_values_range - - -def create_patch_mask_from_ids( - patch_ids, num_patches, window=None, patches_as_queries=False -): - """ - Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) - is True if the patch id at position (i, j) is less than or equal to k. - Args: - patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. - num_patches (int): Total number of patches. - window (int): If not None, only considers patches within a window of size window. - patches_as_queries (bool): If True, the patches are used as queries - Returns: - torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. - """ - bs, seq_len = patch_ids.shape - if not patches_as_queries: - q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) - kv_ids = ( - torch.arange(num_patches, device=patch_ids.device) - .unsqueeze(0) - .unsqueeze(0) - .expand(bs, seq_len, num_patches) - ) - else: - kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) - q_ids = ( - torch.arange(num_patches, device=patch_ids.device) - .unsqueeze(0) - .unsqueeze(-1) - .expand(bs, num_patches, seq_len) - ) - if window is None: - mask = q_ids == kv_ids - else: - mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) - return mask - - -def cross_attn_mask( - patch_ids, - patch_lengths, - N, - patches_as_queries=False, - cross_attn_k=1, - window=None, - block_mask=True, -): - bs = patch_ids.shape[0] - with torch.no_grad(): - # Create the patch mask - cross_mask = create_patch_mask_from_ids( - patch_ids, - patch_lengths.shape[1], - window=window, - patches_as_queries=patches_as_queries, - ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) - q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N - kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k - assert cross_mask.shape == ( - bs, - q_len, - kv_len, - ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" - block_mask = None - if block_mask: - - def patch_mask(b, h, q_idx, kv_idx): - return cross_mask[b, q_idx, kv_idx] - - block_mask = create_block_mask( - patch_mask, - B=bs, - H=None, - Q_LEN=q_len, - KV_LEN=kv_len, - _compile=True, - ) - return block_mask - else: - return torch.where( - cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) - ).unsqueeze( - 1 - ) # [bs, 1, q_len, kv_len] - - -def get_blt_input( - tokens: torch.Tensor, - enforce_patch_size_multiple: bool, - nb_boe: torch.Tensor, - patch_size: int, - boe_id: int, -): - """ - This function returns X_et, X_gt and X_dt, the encoder, global, and decoder - tokens respectively. - - Consider the input and target sequences: - X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] - Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] - with patch_size=4 - - Note 1: that there will be no special tokens introduced at the patch level. - Note 2: X_e needs to be trimmed to be passed to Global - - Current without boe: - X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] - X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch - X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] - Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] - - --> lag fix: - X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] - X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] - X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] - Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] - - Dynamic (current): - X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] - Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] - - entropy patching: - input: 7, bos, 9, 10 - pred (high entropy): eos, 8, 10, eos - - X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] - X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] - X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] - Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] - - --> lag fix no boe (force single byte first patch): - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch - X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] - Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] - - input: 4, 7, bos, 9, 10 - pred (high entropy): 5, eos, 8, 10, eos - - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch - X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] - Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] - - Handle the last byte properly. - patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch - X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] - Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] - - - bpe delim - X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] - X_g = [[3], [4,5,6,7,], [eos,bos,], .. - X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. - Y = [4,5,6,7,, eos,bos, 8,9,, .. - - - Note 1: that there will be no special tokens introduced at the patch level. - Note 2: X_e needs to be trimmed to be passed to Global - """ - batch_size, seq_len = tokens.shape - local_encoder_tokens = tokens - local_decoder_tokens = tokens - - if nb_boe > 0: - padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) - local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) - # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) - - # create global tokens, contains boe tokens and eos - # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) - # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) - # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] - # global_tokens += global_tokens.eq(0).int() * boe_id - # TODO: fix this when we want to use block causal in the global. - - if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: - local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) - - return local_encoder_tokens, None, local_decoder_tokens - - -class LocalModelBase(nn.Module): - def __init__(self, config: BLTConfig, component_type: str = "encoder"): - super().__init__() - - # Store config for later use - self.config = config - - # Use component-specific dimensions - if component_type == "encoder": - self.dim = config.dim_local_encoder - self.n_layers = config.n_layers_local_encoder - self.n_heads = config.n_heads_local_encoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - elif component_type == "decoder": - self.dim = config.dim_local_decoder - self.n_layers = config.n_layers_local_decoder - self.n_heads = config.n_heads_local_decoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - else: - raise ValueError(f"Unknown component_type: {component_type}") - - self.dropout = config.dropout - self.vocab_size = config.vocab_size + config.pm_size - self.patch_size = config.patch_size - - self.attn_impl = config.attn_impl - self.use_rope = config.use_rope - self.init_std_factor = config.init_std_factor - self.init_base_std = config.init_base_std - self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) - self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) - self.cross_attn_k = getattr(config, "cross_attn_k", None) - self.eos_id = config.eos_token_id - - self.boe_id = BOE_ID - - # Initialize cross attention layers as None (will be set by subclasses if needed) - self.cross_attn_layers = None - - # Create parameter dict for BLTTransformerLayers - layer_params = { - 'dim': self.dim, - 'n_heads': self.n_heads, - 'head_dim': config.head_dim, - 'n_kv_heads': getattr(config, 'n_kv_heads', None), - 'rope_theta': config.rope_theta, - 'multiple_of': getattr(config, 'multiple_of', 256), - 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), - 'norm_eps': config.norm_eps, - } - - self.layers = nn.ModuleList( - [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)] - ) - - if not self.use_rope: - self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length - else: - self.rope = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or self.dim // self.n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - self.pos_embeddings = None - - # Set dimension-specific embedding dimensions - if component_type == "encoder": - self.dim_token_emb = config.encoder_dim_token_emb - self.dim_patch_emb = config.encoder_dim_patch_emb - elif component_type == "decoder": - self.dim_token_emb = config.decoder_dim_token_emb - self.dim_patch_emb = config.dim_global - - self.token_embedding_projection = ( - nn.Linear(self.dim_token_emb, self.dim, bias=False) - if self.dim_token_emb is not None and self.dim_token_emb != self.dim - else None - ) - - self.patch_embedding_projection = self._create_patch_projection(config) - - def _should_create_patch_projection(self, config: BLTConfig): - dimension_mismatch = ( - self.dim_patch_emb is not None and self.dim_patch_emb != self.dim - ) - - # Check cross attention conditions - cross_attn_conditions = ( - config.cross_attn_encoder and config.cross_attn_init_by_pooling - ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling) - - return dimension_mismatch or cross_attn_conditions - - def _create_patch_projection(self, config): - if not self._should_create_patch_projection(config): - return None - - output_dim = self.dim_token_emb * (self.cross_attn_k or 1) - - return nn.Linear( - in_features=self.dim_patch_emb, - out_features=output_dim, - bias=False, - ) - - def apply_embedding(self, tokens, embeds): - if embeds is not None: - return embeds - else: - return self.tok_embeddings(tokens) - - def init_weights(self, init_std=None): - self.rope.reset_parameters() - if hasattr(self, "norm"): - self.norm.reset_parameters() - - init_std = init_std or (self.dim ** (-0.5)) - if hasattr(self, "tok_embeddings"): - nn.init.trunc_normal_( - self.tok_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - if self.pos_embeddings is not None: - nn.init.trunc_normal_( - self.pos_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - - if hasattr(self, "output"): - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - if self.token_embedding_projection is not None: - nn.init.trunc_normal_( - self.token_embedding_projection.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - if self.patch_embedding_projection is not None: - patch_emb_std = self.dim_patch_emb ** (-0.5) - nn.init.trunc_normal_( - self.patch_embedding_projection.weight, - mean=0.0, - std=patch_emb_std, - a=-3 * patch_emb_std, - b=3 * patch_emb_std, - ) - - if self.cross_attn_layers is not None: - for depth, layer in enumerate(self.cross_attn_layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(None, factor) - - -class LocalEncoder(LocalModelBase): - def __init__(self, config: BLTConfig): - super().__init__(config, component_type="encoder") - - self.apply_transformer = config.use_local_encoder_transformer - self.downsampling_by_pooling = config.downsampling_by_pooling - self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_nheads = config.cross_attn_nheads - - self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) - - if self.cross_attn_encoder: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) - ) - - def apply_embedding(self, tokens, embeds): - if embeds is not None: - assert ( - self.expects_hash_embeddings - ), "Not expecting embeddings to be passed." - return embeds - else: - return self.tok_embeddings(tokens) - - def forward( - self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor] = None, - patch_embeds: Optional[torch.Tensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, - num_patches: Optional[int] = None, - patch_ids: Optional[torch.Tensor] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - """ """ - bs, seqlen = tokens.shape - if mask is None: - mask = create_causal_mask( - seqlen, - self.attn_impl, - "local_block_causal", - sliding_window=self.sliding_window, - tokens=tokens, - eos_id=self.eos_id, - ) - - h = self.apply_embedding(tokens, embeds) - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None - - h = F.dropout(h, p=self.dropout, training=self.training) - - for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) - # check if cross attention should be applied to either all layer or only the last layer - if self.cross_attn_encoder and ( - i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder - ): - # apply pooling and project - if self.cross_attn_init_by_pooling and patch_embeds is None: - patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) - if self.patch_embedding_projection is not None: - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape( - bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim - ) - - layer_idx = i if self.cross_attn_all_layers_encoder else 0 - patch_embeds_cross = self.cross_attn_layers[layer_idx]( - x=patch_embeds, - kv=h, - mask=cross_mask, - ) - patch_embeds = patch_embeds + patch_embeds_cross - - h_residual = patch_embeds if self.cross_attn_encoder else None - return (h, h_residual), cache - - - - def patch_reduce(self, h, max_num_patches, reduction, patch_ids): - """ - Reduce variable length patches to single embedding per patch - Note: this works with variable number of patches for different sequences in the batch - It handles variable length patches by assuming that patch_lengths will be 0 for any - extra patches on the *right*. Since there can be a variable number of patches - this function also return the number of patches for each sequence in the batch. - Any embeddings on the right that are not allocated to a patch - (i.e. if the sum(patch_lengths[i]) < seq_len for any i) - will be sent to a dummy patch, which is trimmed before returning. - """ - bs, seq_len, emb_dim = h.shape - - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - - reduced_embs = torch.zeros( - (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device - ) - reduced_embs = reduced_embs.scatter_reduce( - src=h, - dim=1, - index=patch_ids, - reduce=reduction, - include_self=False, - ) - reduced_embs = reduced_embs[:, :max_num_patches, :] - - return reduced_embs - - -class LocalDecoder(LocalModelBase): - def __init__(self, config: BLTConfig): - super().__init__(config, component_type="decoder") - - # Model configuration flags - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_nheads = config.cross_attn_nheads - - self.norm = RMSNorm(self.dim, eps=config.norm_eps) - - if self.cross_attn_decoder: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) - ) - - self.output = nn.Linear( - self.dim, - config.vocab_size, - bias=False, - ) - - def forward( - self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor], - patch_embeds: Optional[torch.Tensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - bs, seqlen = tokens.shape - assert embeds is not None, "Embeddings must be provided" - - if mask is None: - mask = create_causal_mask( - seqlen, - self.attn_impl, - "local_block_causal", - sliding_window=self.sliding_window, - tokens=tokens, - eos_id=self.eos_id, - ) - - h = embeds - - if self.patch_embedding_projection is not None: - assert patch_embeds is not None, "Patch embeddings must be passed." - patch_embeds = self.patch_embedding_projection(patch_embeds) - if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape( - bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim - ) - - if patch_embeds is not None and not self.cross_attn_decoder: - h = h + patch_embeds - - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None - - h = F.dropout(h, p=self.dropout, training=self.training) - for i, layer in enumerate(self.layers): - if self.cross_attn_decoder and ( - i == 0 or self.cross_attn_all_layers_decoder - ): - # Use cross attention to extract info from patch_embeds into h - h_cross = self.cross_attn_layers[i]( - x=h, - kv=patch_embeds, - mask=cross_mask, - ) - h = h + h_cross - - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) - - h_preds = self.norm(h) - h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) - h_preds = self.output(h_preds) - h_preds = h_preds.float() - return h_preds, cache - - -class BLTCrossAttention(nn.Module): - """ - BLTCrossAttention block to attend to the encoder states from the decoder. - Rope is not supported. - """ - - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - ): - super().__init__() - - self.dim = dim - self.head_dim = head_dim - - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads - - self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) - self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) - - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, - ) - - def forward( - self, - x: torch.Tensor, - kv: torch.Tensor, - mask: Optional[Union[BlockMask, str]] = None, - ) -> torch.Tensor: - # B S D - bsz, seq_len, _ = x.shape - _, slen_kv, _ = kv.shape - x_norm = self.cross_attn_norm_q(x) - kv = self.cross_attn_norm_kv(kv) - - xq = self.wq(x_norm) - xk = self.wk(kv) - xv = self.wv(kv) - - output_shape = xq.shape - # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) - - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) - - # assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) - #output = flex_attention_comp(xq, xk, xv, block_mask=mask) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask if isinstance(mask, torch.Tensor) else None - mask = mask.to(dtype=xq.dtype).to(xq.device) - output = F.scaled_dot_product_attention( - xq, - xk, - xv, - is_causal=is_causal, - attn_mask=mask, - ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - - output = self.wo(output.reshape(output_shape)) - - return x + output - - def init_weights(self, base_std: float, factor: float = 1.0): - std = base_std or (self.dim ** (-0.5)) / factor - - nn.init.trunc_normal_( - self.wq.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - nn.init.trunc_normal_( - self.wk.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - nn.init.trunc_normal_( - self.wv.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - nn.init.trunc_normal_( - self.wo.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - self.cross_attn_norm_q.reset_parameters() - self.cross_attn_norm_kv.reset_parameters() - - -class GlobalTransformer(nn.Module): - def __init__(self, config): - super().__init__() - - # Store config for later use - self.config = config - - self.dim = config.dim - self.init_base_std = config.init_base_std - self.attn_impl = config.attn_impl - self.attn_bias_type = config.attn_bias_type - self.init_std_factor = config.init_std_factor - self.max_seqlen = config.max_seqlen - self.rope_embeddings = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or config.dim // config.n_heads, - max_seqlen=config.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) - - # Create parameter dict for BLTTransformerLayers - layer_params = { - 'dim': self.dim, - 'n_heads': config.n_heads, - 'head_dim': config.head_dim, - 'n_kv_heads': getattr(config, 'n_kv_heads', None), - 'rope_theta': config.rope_theta, - 'multiple_of': getattr(config, 'multiple_of', 256), - 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), - 'norm_eps': config.norm_eps, - } - - self.layers = nn.ModuleList() - for _ in range(config.n_layers): - self.layers.append(BLTTransformerLayer(layer_params)) - - # GlobalTransformer specific attributes - self.dropout = config.dropout - self.dim_token_emb = config.dim_token_emb - - self.token_embedding_projection = None - if config.dim_token_emb is not None and config.dim_token_emb != self.dim: - self.token_embedding_projection = nn.Linear( - config.dim_token_emb, - config.dim, - bias=False, - ) - - def forward( - self, - tokens: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - embeds: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - bs, seqlen = tokens.shape - - h = embeds - - mask = ( - mask - if mask is not None - else create_causal_mask( - seqlen, - self.attn_impl, - self.attn_bias_type, - tokens=tokens, - eos_id=self.eos_id, - ) - ) - - if self.token_embedding_projection is not None and h.shape[-1] != self.dim: - h = self.token_embedding_projection(h) - - h = F.dropout(h, p=self.dropout, training=self.training) - - freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) - - for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) - - return h, cache - - def init_weights(self): - self.rope_embeddings.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - - # GlobalTransformer specific initialization - std = self.dim_token_emb ** (-0.5) - if self.token_embedding_projection is not None: - nn.init.trunc_normal_( - self.token_embedding_projection.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """ - Compute embeddings using hash token embeddings. - - Args: - local_encoder_tokens: Input tokens tensor - local_encoder: Encoder object with tok_embeddings method - encoder_hash_tok_embedding: ModuleList of hash token embeddings - encoder_hash_byte_group_nb_functions: Number of hash functions - encoder_hash_byte_group_size: List of byte group sizes - encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings - - Returns: - torch.Tensor: Combined embeddings - """ - if encoder_hash_tok_embedding is None: - return None - - local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) - - i = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for byte_group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, - byte_group_size, - hash_func_nb=func_nb, - max_hash=encoder_hash_byte_group_vocab, - ) - hash_tok_embedding = encoder_hash_tok_embedding[i] - local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) - i += 1 - - assert i == len(encoder_hash_tok_embedding) - return local_encoder_embeds - - -class BLTPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - BLT models. - - This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. - It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. - - Args: - config ([`BLTConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. - """ - - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation - _supports_sdpa = True - _supports_cache_class = False - - def _init_weights(self, module): - """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init""" - # Don't do anything here - we use the custom init_weights method instead - pass - - -class BLTModel(BLTPreTrainedModel): - """ - The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences - by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, - and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for - improved performance and inference efficiency. - """ - - def __init__(self, config: BLTConfig): - super().__init__(config) - - # Store config reference - self.config = config - - # Create main components - they will read their parameters from config - self.local_encoder = LocalEncoder(config) - - # Create global-specific config by copying config and overriding dimensions - global_config = type(config)(**config.to_dict()) - global_config.dim = config.dim_global - global_config.n_layers = config.n_layers_global - global_config.n_heads = config.n_heads_global - global_config.n_kv_heads = config.n_kv_heads_global - global_config.dim_token_emb = config.global_dim_patch_emb - - self.global_transformer = GlobalTransformer(global_config) - self.local_decoder = LocalDecoder(config) - - # Initialize hash embeddings - self.encoder_hash_tok_embedding = init_hash_embeddings( - config, - local_encoder_dim=self.local_encoder.dim, - encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, - ) - - # Initialize patcher if needed - if config.patch_in_forward: - if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: - # Load entropy model directly - entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir - - if not os.path.exists(entropy_model_checkpoint_dir): - raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}") - - # Load entropy model parameters - params_path = os.path.join(entropy_model_checkpoint_dir, "params.json") - if not os.path.exists(params_path): - raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}") - - with open(params_path) as fr: - reloaded = json.loads(fr.read()) - - torch.set_default_dtype(torch.bfloat16) - model_params = reloaded["entropy_model"] - logger.warning( - "Update checkpoint to load attn and sliding window args from checkpoint" - ) - - # Override patcher configuration with actual entropy model parameters from checkpoint - config.patcher_dim = model_params["dim"] - config.patcher_n_layers = model_params["n_layers"] - config.patcher_n_heads = model_params["n_heads"] - config.patcher_max_seqlen = model_params["max_seqlen"] - config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"] - config.patcher_vocab_size = model_params["vocab_size"] - # Use sensible defaults for parameters not in checkpoint - config.patcher_attn_bias_type = "local_block_causal" - config.patcher_attn_impl = "sdpa" # originally xformers - config.patcher_sliding_window = 512 - - # BLTPatcher will extract patcher_ parameters from config directly - self.patcher = BLTPatcher(config) - - state_path = os.path.join( - entropy_model_checkpoint_dir, "consolidated.pth" - ) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.patcher.load_state_dict( - torch.load(state_path, map_location=device)["model"], strict=False - ) - self.patcher.to(device) - self.patcher = self.patcher.eval() - # no grads for the model: - for param in self.patcher.parameters(): - param.requires_grad = False - else: - self.patcher = None - - # Initialize weights and apply final processing - self.post_init() - - def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """ - Convert patch lengths to patch IDs for each token position. - - For each token position in the sequence, determines which patch it belongs to. - - Args: - patch_lengths: [batch_size, num_patches] - length of each patch - seq_len: total sequence length - - Returns: - patch_ids: [batch_size, seq_len] - patch index for each token position - - Example: - patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 - seq_len = 10 - Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] - # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 - """ - batch_size, num_patches = patch_lengths.shape - - # Create patch start positions: [0, 3, 5, 9] for the example above - patch_starts = torch.cat([ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total - ], dim=-1) - - # For each token position, find which patch it belongs to - # by finding the rightmost patch start that's <= the position - token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] - - # Broadcasting: patch_starts[batch, patch] <= token_positions[position] - # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t - position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) - - # Count how many patch starts are <= each position, then subtract 1 to get patch index - patch_ids = position_ge_patch_start.sum(dim=-1) - 1 - - return patch_ids - - def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: - """ - Create decoder patch IDs by skipping the first encoder patch. - - The decoder starts after the first patch (which contains BOE tokens), - so we need to map decoder positions to the remaining patches. - - Args: - patch_lengths: [batch_size, num_patches] from encoder - nb_boe: number of beginning-of-example tokens in first patch - seq_len: decoder sequence length - - Returns: - decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices - """ - # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) - decoder_patch_lengths = patch_lengths[:, 1:] - - # Create patch IDs for the decoder sequence using the remaining patches - return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) - - - - def forward( - self, - tokens: torch.Tensor, - patch_lengths: Optional[torch.Tensor] = None, - ): - # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings - # are no longer used in the final BLT model - - bs, N = tokens.shape # Batch size and sequence length - - # Get megabyte inputs - nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) - local_encoder_tokens, _, local_decoder_tokens = get_blt_input( - tokens=tokens, - enforce_patch_size_multiple=False, - nb_boe=nb_boe, - patch_size=self.config.patch_size, - boe_id=BOE_ID, - ) - - # Patching - if patch_lengths is None: - # assert ( - # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward - # ), "Patch in forward not enabled and no patch_lengths passed." - - # PATCHER MODEL DEFINED - if self.config.patching_mode == PatchingModeEnum.entropy: - _, patch_lengths, _ = self.patcher( - local_encoder_tokens, - patch_size=self.config.patch_size, - include_next_token=True, - threshold=self.config.patching_threshold, - threshold_add=self.config.patching_threshold_add, - monotonicity=self.config.monotonicity, - max_patch_length=self.config.max_patch_length, - patching_batch_size=self.config.patching_batch_size, - device=self.config.patching_device, - ) - else: - # self.config.patching_mode == PatchingModeEnum.byte - bs, seq_len = local_encoder_tokens.shape - seq_len_next_tok = seq_len + 1 # include_next_token=True - patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device - ) - - # Apply any processing to patch lengths - if self.config.max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [ - BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) - for pl in patch_lengths.tolist() - ] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device - ) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] - else: - if nb_boe > 0: - patch_lengths[:, 0] += nb_boe - - assert torch.min(patch_lengths) >= 0 - - # Generate patch IDs from patch_lengths - patch_ids = self._patch_ids_from_lengths( - patch_lengths, local_encoder_tokens.shape[-1] - ) - assert torch.max(patch_ids) + 1 <= torch.max( - (patch_lengths != 0).sum(dim=-1) - ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" - - cross_attn_mask_enc = None - # Cross-attention encoder - if self.config.cross_attn_encoder: - cross_attn_mask_enc = cross_attn_mask( - patch_ids, - patch_lengths, - N, - patches_as_queries=True, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_encoder, - block_mask=self.config.cross_attn_use_flex_attention, - ) - - # Hashing and embedding - local_encoder_embeds = compute_hash_embeddings( - local_encoder_tokens=local_encoder_tokens, - local_encoder=self.local_encoder, - encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, - encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, - encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, - encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, - ) - - # NOTE: Frequency-based n-gram embeddings removed as per paper - # The final BLT model uses only hash-based n-gram embeddings - - # Local encoder - (h_encoder, h_cross), cache_encoder = self.local_encoder( - tokens=local_encoder_tokens, - embeds=local_encoder_embeds, - patch_embeds=None, - cross_mask=cross_attn_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - ) - - # Downsampling - h = h_cross.view(bs, patch_lengths.shape[1], -1) - - # Global transformer - global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) - rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) - eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.config.eos_token_id - - h, _ = self.global_transformer( - embeds=h, - tokens=global_tokens, - ) - - # Unpatching - dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] - - # Generate decoder patch IDs - decoder_patch_ids = self._decoder_patch_ids_from_lengths( - patch_lengths, nb_boe, local_decoder_tokens.shape[-1] - ) - assert ( - torch.max(decoder_patch_ids) + 1 <= h.shape[1] - ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" - assert ( - decoder_patch_ids.shape[1] == dec_embeds.shape[1] - ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" - - # Cross-attention decoder - if not self.config.cross_attn_decoder: - h = torch.gather( - h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - ) - cross_attn_mask_dec = None - assert local_decoder_tokens.shape == h.shape[:-1] - else: - cross_attn_mask_dec = cross_attn_mask( - decoder_patch_ids, - patch_lengths, - N, - patches_as_queries=False, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_decoder, - block_mask=self.config.cross_attn_use_flex_attention, - ) - - # Local decoder - output, _ = self.local_decoder( - embeds=dec_embeds, - patch_embeds=h, - tokens=local_decoder_tokens, - cross_mask=cross_attn_mask_dec, - ) - return output - - def init_weights(self): - self.local_encoder.init_weights() - self.global_transformer.init_weights() - self.local_decoder.init_weights() - - if self.encoder_hash_tok_embedding is not None: - emb_std = self.local_encoder.dim ** (-0.5) - for emb in self.encoder_hash_tok_embedding: - nn.init.trunc_normal_( - emb.weight, - mean=0.0, - std=emb_std, - a=-3 * emb_std, - b=3 * emb_std, - ) - - -class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - # Store config reference for later use - self.config = config - - # Extract patcher parameters from BLTConfig - self.dim = config.patcher_dim - self.init_base_std = config.patcher_init_base_std - self.attn_impl = config.patcher_attn_impl - self.attn_bias_type = config.patcher_attn_bias_type - self.init_std_factor = config.patcher_init_std_factor - self.max_seqlen = config.patcher_max_seqlen - n_layers = config.patcher_n_layers - n_heads = config.patcher_n_heads - head_dim = config.patcher_head_dim - rope_theta = config.patcher_rope_theta - rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product - norm_eps = config.patcher_norm_eps - vocab_size = config.patcher_vocab_size - weight_tying = config.patcher_weight_tying - sliding_window = config.patcher_sliding_window - eos_token_id = config.patcher_eos_token_id - - self.rope_embeddings = RotaryEmbedding( - theta=rope_theta, - head_dim=head_dim or self.dim // n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, - ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = eos_token_id - - # Extract additional parameters for BLTTransformerLayer - n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) - multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) - ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) - - # Create a simple parameter dict for BLTTransformerLayer - layer_params = { - 'dim': self.dim, - 'n_heads': n_heads, - 'head_dim': head_dim, - 'n_kv_heads': n_kv_heads, - 'rope_theta': rope_theta, - 'multiple_of': multiple_of, - 'ffn_dim_multiplier': ffn_dim_multiplier, - 'norm_eps': norm_eps, - } - - self.layers = nn.ModuleList() - for _ in range(n_layers): - self.layers.append(BLTTransformerLayer(layer_params)) - - # LMTransformer specific attributes - self.weight_tying = weight_tying - self.sliding_window = sliding_window - - assert vocab_size > 0 - - self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) - - self.norm = RMSNorm(self.dim, eps=norm_eps) - - self.output = nn.Linear( - self.dim, - vocab_size, - bias=False, - ) - - if self.weight_tying: - self.output.weight = self.tok_embeddings.weight - - def forward( - self, - token_values: torch.Tensor, - target: Optional[torch.Tensor] = None, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - attn_impl: str | None = None, - patch_size: Optional[int] = None, - include_next_token: bool = True, - threshold: Optional[float] = None, - threshold_add: Optional[float] = None, - monotonicity: bool = False, - max_patch_length: Optional[int] = None, - patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1 - device: Optional[str] = None, - enable_grad: bool = False, - ): - attn_impl = self.attn_impl if attn_impl is None else attn_impl - - # Handle chunked processing for entropy calculation - # grad_context = nullcontext() if enable_grad else torch.no_grad() - # with grad_context: - entropies = [] - preds = [] - max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) - batch_numel = max_length * patching_batch_size - splits = torch.split(token_values.flatten(), batch_numel) - - for split in splits: - pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros( - pad_size, dtype=split.dtype, device=split.device, requires_grad=False - ) - split = torch.cat((split, pad), dim=0) - split = split.reshape(-1, max_length) - if device is not None: - split = split.to(device) - - # Process chunk: embeddings -> layers -> output - bsz, seqlen = split.shape - h = self.tok_embeddings(split) - chunk_mask = create_causal_mask( - seqlen, - attn_impl, - self.attn_bias_type, - sliding_window=self.sliding_window, - tokens=split, - eos_id=self.eos_id, - ) - freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) - - for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) - - pred = self.output(self.norm(h)) - pred = pred.reshape(-1, pred.shape[-1])[ - : split.numel() - pad_size, : - ] # [batch_size * seq_len, vocab] - preds.append(pred) - pred_entropies = self.entropy(pred) - entropies.append(pred_entropies) - - concat_entropies = torch.cat(entropies, dim=0) - concat_entropies = concat_entropies.reshape(token_values.shape) - concat_preds = torch.cat(preds, dim=0) - concat_preds = concat_preds.reshape(token_values.shape[0], -1) - - # Always compute patch lengths from concatenated entropies - bs, seq_len = token_values.shape - seq_len_next_tok = seq_len + 1 if include_next_token else seq_len - - # Find patch start IDs based on entropy - if patch_size is not None: - patch_start_ids = self.find_entropy_patch_start_ids( - concat_entropies, - patch_size, - include_next_token=include_next_token, - threshold=threshold, - threshold_add=threshold_add, - monotonicity=monotonicity, - ) - patch_lengths = self.patch_lengths_from_start_ids( - patch_start_ids, seq_len_next_tok - ) - else: - # Default to byte-level patching - patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device - ) - - # Apply any processing to patch lengths - if max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [ - self.split_large_numbers(pl, max_patch_length) - for pl in patch_lengths.tolist() - ] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=token_values.dtype, device=token_values.device - ) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = ( - (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - ) - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[ - :, : patch_lengths.shape[1] - last_non_zero_col_reversed - ] - - return concat_entropies, patch_lengths, concat_preds - - def reset_parameters(self, init_std=None): - self.norm.reset_parameters() - - def init_weights(self): - self.reset_parameters() - init_std = self.dim ** (-0.5) - nn.init.trunc_normal_( - self.tok_embeddings.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - self.rope_embeddings.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = self.config.get_init_std_factor(depth) - layer.init_weights(self.init_base_std, factor) - - if not self.weight_tying: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - @staticmethod - def entropy(scores): - """ - scores: [bs, seq_len, vocab] - returns [bs, seq_len] - - Computes the entropy for each token in the batch. - Note: uses natural log. - """ - log_probs = F.log_softmax(scores, dim=-1) - probs = torch.exp(log_probs) - p_log_p = log_probs * probs - entropy = -p_log_p.sum(dim=-1) - return entropy - - - - @staticmethod - def patch_start_ids_from_patch_start_mask(patch_start_mask): - bs, trunc_seq_len = patch_start_mask.shape - max_patches = patch_start_mask.sum(dim=1).max() - if max_patches == 0: - patch_start_ids = torch.full( - (bs, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - else: - patch_ids = ( - torch.arange(trunc_seq_len, device=patch_start_mask.device) - .unsqueeze(0) - .repeat(bs, 1) - ) - extra_patch_ids = torch.full( - (bs, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) - patch_start_mask_padded = torch.cat( - (patch_start_mask, ~patch_start_mask), dim=1 - ) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( - bs, trunc_seq_len - )[:, :max_patches] - return patch_start_ids - - @staticmethod - def patch_lengths_from_start_ids(patch_start_ids, seq_len): - """ - Calculate patch lengths from start ids. - start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then - the rest are filled to the seq len. - seq_len: ex: 7 length of the sequence - - returns the patch lengths: - [1, 6] for the above example. - """ - last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) - patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) - patch_lengths = patch_end_ids - patch_start_ids + 1 - assert torch.all(patch_lengths >= 0), f"{patch_lengths}" - assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" - return patch_lengths - - @staticmethod - def find_entropy_patch_start_ids( - entropies, - patch_size=None, - threshold=None, - threshold_add=None, - monotonicity=False, - include_next_token=True, - ): - """ - Use entropies to find the start ids of each patch. - Use patch_size or threshold to figure out the total number of patches to allocate. - - When threshold is not None the number of patches is not constant between - different sequences, but patches can be identified incrementally rather than - decided globally using the entire sequence. - """ - bs, seq_len = entropies.shape[:2] - - first_ids = ( - torch.tensor([0, 1], dtype=torch.long, device=entropies.device) - .unsqueeze(0) - .repeat(bs, 1) - ) - preds_truncation_len = first_ids.shape[ - 1 - ] # remove the first preds because they will be start of patches. - entropies = entropies[:, 1:] - if threshold is None: - num_patches = seq_len // patch_size - patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices - patch_start_ids = patch_start_ids.sort(dim=1).values - else: - patch_start_mask = entropies > threshold - if not include_next_token: - patch_start_mask = patch_start_mask[:, :-1] - # patch_start_mask[1:] |= tokens[:-1] < OFFSET - patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - - patch_start_ids = torch.cat( - (first_ids, patch_start_ids + preds_truncation_len), dim=1 - ) - return patch_start_ids - - @staticmethod - def split_large_numbers(lst, m): - new_lst = [] - for i in lst: - if i > m: - while i > m: - new_lst.append(m) - i -= m - new_lst.append(i) - else: - new_lst.append(i) - assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" - return new_lst - - -def init_hash_embeddings( - config, - local_encoder_dim: int, - encoder_hash_byte_group_size: list, -): - """Initialize hash-based token embeddings for the BLT encoder.""" - if config.encoder_hash_byte_group_size is None: - return None - - embeddings = [] - emb_dim = local_encoder_dim - encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - - for _ in range(config.encoder_hash_byte_group_nb_functions): - for _ in encoder_hash_byte_group_size: - embeddings.append( - nn.Embedding( - encoder_hash_byte_group_vocab, - emb_dim, - ) - ) - - return nn.ModuleList(embeddings) - - -__all__ = [ - "BLTPreTrainedModel", - "BLTModel", - "BLTPatcher", - "LocalEncoder", - "LocalDecoder", - "GlobalTransformer", -] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenization_blt.py b/src/transformers/models/blt_wip/tokenization_blt.py new file mode 100644 index 000000000000..cf57143de5dd --- /dev/null +++ b/src/transformers/models/blt_wip/tokenization_blt.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for BLT.""" + +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +# BLT tokenizer constants +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 +BYTE_UNITS: int = 256 + +VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files + + +class BLTTokenizer(PreTrainedTokenizer): + """ + Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token. + + This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. + It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), + beginning of example (BOE), and padding (PAD). + + Args: + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The padding token. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. Not used in BLT but kept for compatibility. + boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of example token, specific to BLT. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add a `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + boe_token="", + add_bos_token=True, + add_eos_token=True, + clean_up_tokenization_spaces=False, + spaces_between_special_tokens=False, + **kwargs, + ): + # Store BLT-specific parameters first + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.vocab_size_unit_1 = BYTE_UNITS + self.offsetting_special_char = OFFSET + + # BLT token IDs (exactly like original) + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char + + # Convert string tokens to AddedToken objects + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.vocab_size_unit_1 + self.offsetting_special_char + + def get_vocab(self): + """Returns vocab as a dict""" + # Create a mapping for byte values + offset + vocab = {} + + # Add special tokens (with defensive checks) + if hasattr(self, 'bos_token'): + vocab[str(self.bos_token)] = self.bos_id + if hasattr(self, 'eos_token'): + vocab[str(self.eos_token)] = self.eos_id + if hasattr(self, 'pad_token'): + vocab[str(self.pad_token)] = self.pad_id + if hasattr(self, 'boe_token'): + vocab[str(self.boe_token)] = self.boe_id + + # Add byte tokens as string representations of byte values + vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS) + offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET) + for i in range(vocab_size_unit_1): + vocab[str(i)] = i + offsetting_special_char + + # Add any additional tokens if available + if hasattr(self, 'added_tokens_encoder'): + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. For BLT, we work directly with byte values. + Returns a list of strings that represent the byte values. + """ + # Convert text to UTF-8 bytes, just like the original + try: + bytes_data = text.encode("utf-8", errors="ignore") + except UnicodeEncodeError: + bytes_data = text.encode("utf-8", errors="ignore") + + # Return string representations of byte values for the tokenizer framework + return [str(byte_val) for byte_val in bytes_data] + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id using the vocab.""" + # Handle special tokens + if token == str(self.bos_token): + return self.bos_id + elif token == str(self.eos_token): + return self.eos_id + elif token == str(self.pad_token): + return self.pad_id + elif token == str(self.boe_token): + return self.boe_id + else: + try: + # Convert byte value string to int and add offset (like original) + byte_val = int(token) + if 0 <= byte_val <= 255: + return byte_val + self.offsetting_special_char + except ValueError: + pass + + # Check if it's in added tokens + return self.added_tokens_encoder.get(token, self.unk_token_id) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the vocab.""" + # Handle special tokens + if index == self.bos_id: + return str(self.bos_token) + elif index == self.eos_id: + return str(self.eos_token) + elif index == self.pad_id: + return str(self.pad_token) + elif index == self.boe_id: + return str(self.boe_token) + elif index >= self.offsetting_special_char and index < self.vocab_size: + # Convert back to byte value (like original) + byte_val = index - self.offsetting_special_char + return str(byte_val) + else: + # Check added tokens + for token, token_id in self.added_tokens_encoder.items(): + if token_id == index: + return token + return str(self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens to a single string.""" + byte_values = [] + + for token in tokens: + # Skip special tokens + if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]: + continue + + try: + # Convert token back to byte value (like original decode method) + byte_val = int(token) + if 0 <= byte_val <= 255: + byte_values.append(byte_val) + except ValueError: + continue + + # Convert byte values back to string (exactly like original) + try: + return bytes(byte_values).decode("utf-8", errors="ignore") + except (UnicodeDecodeError, ValueError): + return "" + + def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): + """ + Encode text exactly like the original BLT tokenizer. + """ + if add_bos is None: + add_bos = self.add_bos_token + if add_eos is None: + add_eos = self.add_eos_token + + # Since bpe_delim=False, we use the simple byte encoding + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting (exactly like original) + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + """ + Decode tokens exactly like the original BLT tokenizer. + """ + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break + return bytes( + [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] + ).decode("utf-8", errors="ignore") + + def get_vocab_size(self) -> int: + """Get vocab size like the original tokenizer.""" + return self.vocab_size_unit_1 + self.offsetting_special_char + +__all__ = ["BLTTokenizer"] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenizers/__init__.py b/src/transformers/models/blt_wip/tokenizers/__init__.py deleted file mode 100644 index 71ca4b12c770..000000000000 --- a/src/transformers/models/blt_wip/tokenizers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py deleted file mode 100644 index ff31d655ae34..000000000000 --- a/src/transformers/models/blt_wip/tokenizers/abstract_tokenizer.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -import abc - - -class Tokenizer(abc.ABC): - @abc.abstractmethod - def encode(self, text: str, add_bos: bool, add_eos: bool): - pass - - @abc.abstractmethod - def decode(self, tokens: list[int]): - pass - - @abc.abstractmethod - def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]: - """Return the offsets of the tokens in the original text. Only used for evaluation.""" - pass - - @abc.abstractmethod - def get_vocab_size(self) -> int: - pass diff --git a/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py deleted file mode 100644 index 2d018ff90ead..000000000000 --- a/src/transformers/models/blt_wip/tokenizers/blt_tokenizer.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -import re - -from .abstract_tokenizer import Tokenizer -from .sentence_piece_tokenizer import SentencePieceTokenizer - - -SEP = " " -BOS_ID: int = 1 -EOS_ID: int = 2 -PAD_ID: int = -1 -BOE_ID: int = 0 -BPE_ID: int = 3 -OFFSET: int = 4 - -BYTE_UNITS: int = 256 - - -def convert_to_bytes(s): - # check if the output is a bytes like object of the format <0x00> - if re.match(r"<0x[0-9a-fA-F]+>", s): - return bytes.fromhex(s[3:-1]) - else: - return bytes(s, "utf-8", errors="ignore") - - -def text2bytes_bpe_delims( - text: str, - *, - bpe_tokenizer, - bpe_id: int, - offsetting_special_char: int, - add_bos: bool, - add_eos: bool, -): - cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) - # merge the leading space tokens - leading_space_tokens = [] - other_bpe_tokens = [] - leading = True - for token in cur_bpe: - bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) - if leading and all(c == "▁" for c in bpe_str): - leading_space_tokens.append(bpe_str) - else: - leading = False - other_bpe_tokens.append(bpe_str) - cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens - - # Remove the '▁' characters - bpe_strs = [] - for i, bpe_str in enumerate(cur_bpe_strs): - if len(bpe_strs) <= 1 and all([c == " " for s in bpe_strs for c in s]) and not all(c == "▁" for c in bpe_str): - # Remove leading space for first non space token. - bpe_str = bpe_str.replace("▁", "") - elif i == 0 and all(c == "▁" for c in bpe_str): - bpe_str = " " * (len(text) - len(text.lstrip(" "))) - else: - bpe_str = bpe_str.replace("▁", " ") - if len(bpe_str) > 0: - bpe_strs.append(bpe_str) - ex_seq = [] - # Convert bpe tokens to bytes - for s in bpe_strs: - byte_chunk = convert_to_bytes(s) - proc_chunk = [int(unit) for unit in byte_chunk] - ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) - - return ex_seq - - -class BltTokenizer(Tokenizer): - def __init__( - self, - *, - vocab_size_unit_1: int = BYTE_UNITS, - bpe_delim: bool = False, - bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - add_bos: bool = True, - add_eos: bool = True, - ): - self.add_bos = add_bos - self.add_eos = add_eos - self.vocab_size_unit_1 = vocab_size_unit_1 - self.boe_id = BOE_ID - self.bos_id = BOS_ID - self.eos_id = EOS_ID - self.pad_id = PAD_ID - self.bpe_id = BPE_ID - self.bpe_tokenizer_path = bpe_tokenizer_path - if bpe_delim: - self.bpe_tokenizer = SentencePieceTokenizer(model_path=self.bpe_tokenizer_path) - else: - self.bpe_tokenizer = None - self.bpe_delim = bpe_delim - self.offsetting_special_char = OFFSET - self.vocab_size_unit_1 = vocab_size_unit_1 - self.n_words = vocab_size_unit_1 + self.offsetting_special_char - - def get_vocab_size(self) -> int: - return self.n_words - - def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): - if add_bos is None: - add_bos = self.add_bos - if add_eos is None: - add_eos = self.add_eos - - if self.bpe_delim: - tokens = text2bytes_bpe_delims( - text, - bpe_tokenizer=self.bpe_tokenizer, - bpe_id=self.bpe_id, - offsetting_special_char=self.offsetting_special_char, - add_bos=False, - add_eos=False, - ) - else: - tokens = bytes(text, encoding="utf-8", errors="ignore") - - # Offsetting - tokens = [int(unit) + self.offsetting_special_char for unit in tokens] - - if add_bos: - tokens.insert(0, self.bos_id) - if add_eos: - tokens.append(self.eos_id) - - return tokens - - def decode(self, tokens: list[int], cut_at_eos: bool = False): - if cut_at_eos: - for k, t in enumerate(tokens): - if t == self.eos_id: - tokens = tokens[: k + 1] - break - return bytes( - [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] - ).decode("utf-8", errors="ignore") - - def get_token_offsets(self, text: str, tokens: list[int] | None = None): - # TODO: Figure out what this does - raise NotImplementedError() diff --git a/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py b/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py deleted file mode 100644 index fece8cbce7a8..000000000000 --- a/src/transformers/models/blt_wip/tokenizers/sentence_piece_tokenizer.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -import logging -import os - - -try: - from sentencepiece import SentencePieceProcessor - - has_sp = True -except ImportError: - has_sp = False - -from .abstract_tokenizer import Tokenizer - - -logger = logging.getLogger(__name__) - - -class SentencePieceTokenizer(Tokenizer): - def __init__(self, model_path: str, add_bos: bool = True, add_eos: bool = True) -> None: - assert os.path.isfile(model_path), model_path - self.sp_model = SentencePieceProcessor(model_file=model_path) - - logger.info(f"Reloaded SentencePiece model from {model_path}") - - # BOS / EOS token IDs - self.n_words: int = self.sp_model.vocab_size() - self.bos_id: int = self.sp_model.bos_id() - self.eos_id: int = self.sp_model.eos_id() - self.pad_id: int = self.sp_model.pad_id() - self.add_bos = add_bos - self.add_eos = add_eos - logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") - assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() - - def get_vocab_size(self) -> int: - return self.n_words - - def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None): - if add_bos is None: - add_bos = self.add_bos - - if add_eos is None: - add_eos = self.add_eos - assert type(s) is str - tokens = [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos - return tokens - - def decode(self, tokens: list[int]): - return self.sp_model.decode(tokens) - - def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]: - pieces = self.sp_model.encode_as_immutable_proto(text).pieces - substrs = [p.surface for p in pieces] - offsets = [p.begin for p in pieces] - return substrs, offsets diff --git a/src/transformers/models/blt_wip/unified_blt_debug/config.json b/src/transformers/models/blt_wip/unified_blt_debug/config.json deleted file mode 100644 index 67ab54a27541..000000000000 --- a/src/transformers/models/blt_wip/unified_blt_debug/config.json +++ /dev/null @@ -1,144 +0,0 @@ -{ - "args": { - "alpha_depth": "disabled", - "architecture": "vanilla", - "attn_bias_type": "block_causal", - "attn_impl": "xformers", - "attn_to_keep": "all", - "conv_kernel_size": null, - "cross_attn_all_layers_decoder": true, - "cross_attn_all_layers_encoder": false, - "cross_attn_decoder": true, - "cross_attn_encoder": true, - "cross_attn_init_by_pooling": true, - "cross_attn_k": 2, - "cross_attn_nheads": 16, - "cross_attn_use_flex_attention": true, - "cross_attn_window_decoder": null, - "cross_attn_window_encoder": null, - "custom_bwd": false, - "dim": 512, - "dim_global": 2048, - "dim_local_decoder": 1024, - "dim_local_encoder": 1024, - "dim_patch_emb": null, - "dim_token": null, - "dim_token_emb": null, - "downsampling_by_pooling": "max", - "dropout": 0.0, - "encoder_enable_byte_group_hash": false, - "encoder_enable_byte_ngrams": false, - "encoder_hash_byte_group_nb_functions": 1, - "encoder_hash_byte_group_size": [ - 3, - 4, - 5, - 6, - 7, - 8 - ], - "encoder_hash_byte_group_vocab": 500002, - "encoder_lm_loss": false, - "encoder_ngram_table_dir": null, - "encoder_ngram_to_size_str": null, - "encoder_preds_low_entropy_toks": null, - "encoder_preds_random_toks": null, - "entropy_model_checkpoint_dir": null, - "entropy_model_is_ngram_model": false, - "eos_id": 2, - "ffn_dim_multiplier": 1.0, - "full_logging_n_layers": 4, - "fuse_sequence_parallel": false, - "global_local_decoder_residual_layer": null, - "head_dim": null, - "init_base_std": null, - "init_std_factor": "current_depth", - "init_use_depth": "current", - "init_use_gaussian": true, - "layer_ckpt": "none", - "local_attention_window_len": 512, - "log_patch_lengths": false, - "loss_parallel": false, - "max_encoder_seq_length": 24576, - "max_length": 256, - "max_patch_length": null, - "max_seqlen": 4096, - "monotonicity": false, - "multiple_of": 256, - "n_heads": 8, - "n_heads_global": 16, - "n_heads_local_decoder": 16, - "n_heads_local_encoder": 16, - "n_kv_heads": null, - "n_kv_heads_global": null, - "n_layers": 8, - "n_layers_global": 25, - "n_layers_local_decoder": 9, - "n_layers_local_encoder": 1, - "ngram_vocab_sizes": null, - "non_linearity": "swiglu", - "norm_affine": true, - "norm_eps": 1e-05, - "norm_type": "rmsnorm", - "output_size": -1, - "pad_to_max_length": true, - "patch_in_forward": true, - "patch_size": 4.5, - "patching_batch_size": 1, - "patching_device": "cuda", - "patching_mode": "entropy", - "patching_threshold": 1.335442066192627, - "patching_threshold_add": null, - "patching_thresholds_str": null, - "pm_size": 0, - "pre_norm": true, - "recompute_attn": false, - "recompute_fc1_out": false, - "recompute_fc3_out": false, - "rope_theta": 500000.0, - "rope_use_fp32_in_outer_product": true, - "seed": 42, - "sequence_parallel": false, - "share_encoder_decoder_emb": true, - "tie_local_encoder_decoder": false, - "tie_local_encoder_decoder_logits": false, - "tokenize_with_bpe_delimiter": false, - "use_fsdp": true, - "use_local_encoder_transformer": true, - "use_rope": true, - "vocab_size": 260, - "weight_tying": false - }, - "patch_in_forward": true, - "realtime_patching": true, - "patching_mode": "entropy", - "patch_size": 4.5, - "patching_threshold": 1.335442066192627, - "patching_threshold_add": null, - "max_patch_length": null, - "patching_batch_size": 1, - "patching_device": "cuda", - "monotonicity": false, - "patcher_vocab_size": 260, - "patcher_dim": 768, - "patcher_n_layers": 14, - "patcher_n_heads": 12, - "patcher_head_dim": null, - "patcher_n_kv_heads": null, - "patcher_max_seqlen": 8192, - "patcher_norm_eps": 1e-05, - "patcher_dropout": 0.0, - "patcher_sliding_window": 512, - "patcher_ffn_dim_multiplier": 1.0, - "patcher_multiple_of": 256, - "patcher_rope_theta": 10000.0, - "patcher_rope_use_fp32_in_outer_product": false, - "patcher_attn_impl": "xformers", - "patcher_attn_bias_type": "local_block_causal", - "patcher_init_base_std": null, - "patcher_init_std_factor": "current_depth", - "patcher_dim_token_emb": null, - "patcher_weight_tying": false, - "patcher_bos_token_id": 1, - "patcher_eos_token_id": 2 -} \ No newline at end of file From ad8c7a890290f9d0fd3c0bc6cba3ffddc0dc403e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 16 Jun 2025 08:35:21 +0000 Subject: [PATCH 016/139] clean up dir --- .../convert_hf_blt_original_to_unified.py | 540 ------------------ 1 file changed, 540 deletions(-) delete mode 100644 src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py diff --git a/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py b/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py deleted file mode 100644 index dad247b19c62..000000000000 --- a/src/transformers/models/blt_wip/convert_hf_blt_original_to_unified.py +++ /dev/null @@ -1,540 +0,0 @@ -import argparse -import json -import logging -import os -from typing import Dict, Any, Optional - -import torch -from huggingface_hub import hf_hub_download, snapshot_download -from safetensors.torch import load_file, save_file - -from transformers.utils import logging as transformers_logging - -logger = transformers_logging.get_logger(__name__) -transformers_logging.set_verbosity_info() - -# For standalone execution, we'll skip the model validation to avoid import issues -# The script will create the unified config and weights files without testing model instantiation -ENABLE_MODEL_VALIDATION = False - -import sys -import os - -from transformers.models.blt_wip.modeling_blt_wip import BLTModel -from transformers.models.blt_wip.configuration_blt import BLTConfig - - -ENABLE_MODEL_VALIDATION = True - -def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: - """ - Download all necessary files from HuggingFace Hub. - - Args: - model_id: HuggingFace model ID (e.g., "facebook/blt-1b") - cache_dir: Optional cache directory - - Returns: - Dictionary with paths to downloaded files - """ - logger.info(f"Downloading model files from {model_id}...") - - try: - # Download main config - config_path = hf_hub_download( - repo_id=model_id, - filename="config.json", - cache_dir=cache_dir - ) - - # Download main model weights - weights_path = hf_hub_download( - repo_id=model_id, - filename="model.safetensors", - cache_dir=cache_dir - ) - - # Download entropy model params - entropy_params_path = hf_hub_download( - repo_id=model_id, - filename="entropy_model/params.json", - cache_dir=cache_dir - ) - - # Download entropy model weights - entropy_weights_path = hf_hub_download( - repo_id=model_id, - filename="entropy_model/consolidated.pth", - cache_dir=cache_dir - ) - - return { - "config": config_path, - "weights": weights_path, - "entropy_params": entropy_params_path, - "entropy_weights": entropy_weights_path - } - - except Exception as e: - logger.error(f"Failed to download files from {model_id}: {e}") - raise - - -def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: - """ - Merge main configuration with entropy model parameters. - - Args: - config_path: Path to main config.json - entropy_params_path: Path to entropy_model/params.json - - Returns: - Merged configuration dictionary - """ - logger.info("Merging configurations...") - - # Load main configuration - with open(config_path, 'r') as f: - main_config = json.load(f) - - # Load entropy model parameters - with open(entropy_params_path, 'r') as f: - entropy_data = json.load(f) - - # Extract entropy model and patcher parameters - entropy_model_params = entropy_data.get("entropy_model", {}) - patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) - - # Create unified configuration - unified_config = main_config.copy() - - # Ensure required main model parameters are present with correct types - # Sometimes the original config may have different key names - if "vocab_size" not in unified_config: - unified_config["vocab_size"] = int(main_config.get("vocab_size", 256)) - if "dim" not in unified_config: - unified_config["dim"] = int(main_config.get("dim", main_config.get("hidden_size", main_config.get("d_model", 512)))) - if "n_layers" not in unified_config: - unified_config["n_layers"] = int(main_config.get("n_layers", main_config.get("num_layers", main_config.get("num_hidden_layers", 8)))) - if "n_heads" not in unified_config: - unified_config["n_heads"] = int(main_config.get("n_heads", main_config.get("num_attention_heads", main_config.get("num_heads", 8)))) - if "max_seqlen" not in unified_config: - unified_config["max_seqlen"] = int(main_config.get("max_seqlen", main_config.get("max_position_embeddings", main_config.get("seq_length", 1024)))) - - # Ensure other integer parameters are properly typed - for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: - if key in unified_config and not isinstance(unified_config[key], int): - unified_config[key] = int(unified_config[key]) - - # Convert all patch_size values to integers to avoid float/int type errors - patch_size = patcher_args.get("patch_size", 8) - if isinstance(patch_size, float): - patch_size = int(patch_size) - - # Add patching configuration - unified_config.update({ - "patch_in_forward": True, - "realtime_patching": True, - "patching_mode": "entropy", - - # Patcher arguments - "patch_size": patch_size, - "patching_threshold": patcher_args.get("threshold", 0.5), - "patching_threshold_add": patcher_args.get("threshold_add", 0.0), - "max_patch_length": patcher_args.get("max_patch_length"), - "patching_batch_size": patcher_args.get("patching_batch_size", 1), - "patching_device": patcher_args.get("patching_device", "cuda"), - "monotonicity": patcher_args.get("monotonicity", False), - - # Entropy model (patcher) architecture parameters - "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), - "patcher_dim": int(entropy_model_params.get("dim", 512)), - "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), - "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), - "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None, - "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, - "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), - "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), - "patcher_dropout": entropy_model_params.get("dropout", 0.0), - "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None, - "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), - "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), - "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), - "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False), - "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), - "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), - "patcher_init_base_std": entropy_model_params.get("init_base_std"), - "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), - "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), - "patcher_weight_tying": entropy_model_params.get("weight_tying", False), - "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), - "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), - }) - - logger.info(f"Merged configuration with {len(unified_config)} parameters") - return unified_config - - -def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: - """ - Merge main model weights with entropy model weights. - - Args: - weights_path: Path to main model.safetensors - entropy_weights_path: Path to entropy_model/consolidated.pth - - Returns: - Merged state dictionary - """ - logger.info("Merging model weights...") - - # Load main model weights - main_weights = load_file(weights_path) - logger.info(f"Loaded main model weights: {len(main_weights)} tensors") - - # Load entropy model weights - entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True) - - # Handle nested entropy model structure - if 'model' in entropy_weights: - entropy_weights = entropy_weights['model'] - elif 'state_dict' in entropy_weights: - entropy_weights = entropy_weights['state_dict'] - - logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") - - # Create unified state dict - unified_weights = main_weights.copy() - - # Add entropy model weights with "patcher." prefix - for key, tensor in entropy_weights.items(): - patcher_key = f"patcher.{key}" - unified_weights[patcher_key] = tensor - - logger.info(f"Merged weights: {len(unified_weights)} tensors total") - return unified_weights - - -def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): - """ - Create tokenizer configuration file. - - Args: - output_dir: Output directory - config: Model configuration - """ - logger.info("Creating tokenizer configuration...") - - tokenizer_config = { - "tokenizer_class": "BltTokenizer", - "vocab_size": config.get("vocab_size", 256), - "model_max_length": config.get("max_seqlen", 1024), - "add_bos_token": True, - "add_eos_token": True, - "bos_token": "", - "eos_token": "", - "pad_token": "", - "unk_token": "", - } - - tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") - with open(tokenizer_path, 'w') as f: - json.dump(tokenizer_config, f, indent=2) - - logger.info(f"Tokenizer config saved to {tokenizer_path}") - - -def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): - """ - Validate the unified model configuration and weights. - - Args: - config: Unified configuration - weights: Unified weights - """ - logger.info("Validating unified model...") - - # Check required configuration keys - required_keys = [ - "vocab_size", "dim", "n_layers", "n_heads", - "patch_in_forward", "patcher_vocab_size", "patcher_dim" - ] - - missing_keys = [key for key in required_keys if key not in config] - if missing_keys: - logger.warning(f"Missing configuration keys: {missing_keys}") - - # Check for patcher weights - patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] - if not patcher_weights: - logger.warning("No patcher weights found in unified weights") - else: - logger.info(f"Found {len(patcher_weights)} patcher weight tensors") - - # Check for main model weights - main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] - logger.info(f"Found {len(main_weights)} main model weight tensors") - - # Try to create the model with the configuration (if imports are available) - if ENABLE_MODEL_VALIDATION and BLTConfig is not None and BLTModel is not None: - try: - logger.info("Testing model instantiation...") - - # Debug: Print config keys to help diagnose issues - logger.debug(f"Config keys: {list(config.keys())}") - logger.debug(f"Config vocab_size: {config.get('vocab_size')} (type: {type(config.get('vocab_size'))})") - logger.debug(f"Config dim: {config.get('dim')} (type: {type(config.get('dim'))})") - - blt_config = BLTConfig(**config) - model = BLTModel(blt_config) - logger.info("✓ Model instantiation successful") - - # Try to load the weights - logger.info("Testing weight loading...") - try: - missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) - if missing_keys: - logger.warning(f"Missing keys during weight loading: {missing_keys[:5]}...") # Show first 5 - if unexpected_keys: - logger.warning(f"Unexpected keys during weight loading: {unexpected_keys[:5]}...") # Show first 5 - logger.info("✓ Weight loading successful") - except Exception as weight_error: - logger.warning(f"Weight loading failed: {weight_error}") - logger.info("Model instantiation successful, but weight loading had issues") - - except Exception as e: - logger.error(f"Model validation failed: {e}") - logger.debug(f"Full error details:", exc_info=True) - logger.warning("Model may not be compatible with modeling_blt_wip.py") - logger.info("You can still use the converted files and test manually") - else: - logger.info("Skipping model instantiation test (BLT classes not available)") - logger.info("You can test the model manually after conversion") - - logger.info("Model validation completed") - - -def convert_hf_blt_to_unified( - model_id: str, - output_dir: str, - config_name: str = "config.json", - weights_name: str = "pytorch_model.bin", - safe_serialization: bool = True, - cache_dir: Optional[str] = None, - validate: bool = True, -) -> None: - """ - Convert BLT model from HuggingFace Hub format to unified format. - - Args: - model_id: HuggingFace model ID (e.g., "facebook/blt-1b") - output_dir: Output directory for unified model - config_name: Name for unified config file - weights_name: Name for unified weights file - safe_serialization: Whether to use safetensors format - cache_dir: Cache directory for downloads - validate: Whether to validate the unified model - """ - logger.info(f"Converting {model_id} to unified format...") - - # Download model files - file_paths = download_model_files(model_id, cache_dir) - - # Merge configurations - unified_config = merge_configurations( - file_paths["config"], - file_paths["entropy_params"] - ) - - # Merge weights - unified_weights = merge_weights( - file_paths["weights"], - file_paths["entropy_weights"] - ) - - # Validate if requested - if validate: - validate_unified_model(unified_config, unified_weights) - - # Create output directory - os.makedirs(output_dir, exist_ok=True) - - # Save unified configuration - config_path = os.path.join(output_dir, config_name) - with open(config_path, 'w') as f: - json.dump(unified_config, f, indent=2) - logger.info(f"Unified config saved to {config_path}") - - # Save unified weights - if safe_serialization and weights_name.endswith('.bin'): - weights_name = weights_name.replace('.bin', '.safetensors') - elif not safe_serialization and weights_name.endswith('.safetensors'): - weights_name = weights_name.replace('.safetensors', '.bin') - - weights_path = os.path.join(output_dir, weights_name) - if safe_serialization: - save_file(unified_weights, weights_path) - else: - torch.save(unified_weights, weights_path) - logger.info(f"Unified weights saved to {weights_path}") - - # Create tokenizer config - create_tokenizer_config(output_dir, unified_config) - - # Create README - readme_path = os.path.join(output_dir, "README.md") - with open(readme_path, 'w') as f: - f.write(f"""# Unified BLT Model - -This model was converted from {model_id} to unified format compatible with modeling_blt_wip.py. - -## Files - -- `{config_name}`: Unified configuration (main config + entropy model params) -- `{weights_name}`: Unified weights (main model + entropy model weights with "patcher." prefix) -- `tokenizer_config.json`: Tokenizer configuration - -## Usage - -```python -import torch -import json -from modeling_blt_wip import BLTModel, BLTConfig - -# Load configuration -with open('{config_name}', 'r') as f: - config_dict = json.load(f) - -config = BLTConfig(**config_dict) - -# Load model -model = BLTModel(config) - -# Load weights -if '{weights_name}'.endswith('.safetensors'): - from safetensors.torch import load_file - state_dict = load_file('{weights_name}') -else: - state_dict = torch.load('{weights_name}', map_location='cpu') - -model.load_state_dict(state_dict, strict=False) -``` - -## Original Model - -Converted from: {model_id} -""") - - logger.info(f"Conversion completed! Unified model saved to: {output_dir}") - - -def main(): - parser = argparse.ArgumentParser( - description="Convert BLT models from HuggingFace Hub format to unified format", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Convert facebook/blt-1b to unified format - python convert_hf_blt_to_unified.py \\ - --model_id facebook/blt-1b \\ - --output_dir ./unified_blt_1b - - # Convert with custom file names - python convert_hf_blt_to_unified.py \\ - --model_id facebook/blt-7b \\ - --output_dir ./unified_blt_7b \\ - --config_name unified_config.json \\ - --weights_name unified_model.safetensors - - # Convert without validation - python convert_hf_blt_to_unified.py \\ - --model_id facebook/blt-1b \\ - --output_dir ./my_blt \\ - --no_validate - """ - ) - - # Required arguments (with defaults for debugging) - parser.add_argument( - "--model_id", - type=str, - default="facebook/blt-1b", - help="HuggingFace model ID (e.g., facebook/blt-1b)" - ) - parser.add_argument( - "--output_dir", - type=str, - default="./unified_blt_debug", - help="Output directory for unified model" - ) - - # Optional arguments - parser.add_argument( - "--config_name", - type=str, - default="config.json", - help="Name for unified config file (default: config.json)" - ) - parser.add_argument( - "--weights_name", - type=str, - default="pytorch_model.bin", - help="Name for unified weights file (default: pytorch_model.bin)" - ) - parser.add_argument( - "--safe_serialization", - action="store_true", - default=True, - help="Use safetensors format for weights (default: True)" - ) - parser.add_argument( - "--no_safe_serialization", - dest="safe_serialization", - action="store_false", - help="Use .bin format instead of safetensors" - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="Cache directory for downloads" - ) - parser.add_argument( - "--no_validate", - dest="validate", - action="store_false", - default=True, - help="Skip model validation" - ) - parser.add_argument( - "--debug", - action="store_true", - default=True, # Enable debug by default for easier debugging - help="Enable debug logging" - ) - - args = parser.parse_args() - - # Setup logging - if args.debug: - transformers_logging.set_verbosity_debug() - logging.basicConfig(level=logging.DEBUG) - - # Run conversion - try: - convert_hf_blt_to_unified( - model_id=args.model_id, - output_dir=args.output_dir, - config_name=args.config_name, - weights_name=args.weights_name, - safe_serialization=args.safe_serialization, - cache_dir=args.cache_dir, - validate=args.validate, - ) - except Exception as e: - logger.error(f"Conversion failed: {e}") - raise - - -if __name__ == "__main__": - main() \ No newline at end of file From aff63d6b56da2f6835669117e48598e46a3eae5e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 16 Jun 2025 11:34:39 +0000 Subject: [PATCH 017/139] cleaned up modeling further --- src/demo_hf.py | 2 +- .../models/blt_wip/modeling_blt.py | 433 +++++------------- 2 files changed, 106 insertions(+), 329 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index a1fa640c5c42..65715861f45f 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -3,7 +3,7 @@ import torch -from transformers.models.blt_wip.modeling_blt_wip import BLTModel +from transformers.models.blt_wip.modeling_blt import BLTModel from transformers.models.blt_wip.tokenization_blt import BLTTokenizer diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index fe846fc0ab9f..831791eabeaf 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -16,17 +16,6 @@ PatchingModeEnum, ) - -SEP = " " -BOS_ID: int = 1 -EOS_ID: int = 2 -PAD_ID: int = -1 -BOE_ID: int = 0 -BPE_ID: int = 3 -OFFSET: int = 4 - -BYTE_UNITS: int = 256 - RMSNorm = nn.RMSNorm logger = logging.getLogger() @@ -417,13 +406,6 @@ def forward( out = h + self.feed_forward(h_norm) return out - - - -def rightpad(seq, pad_id, max_len): - return seq + [pad_id] * (max_len - len(seq)) - - def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 shifted_mask = torch.cat( @@ -436,17 +418,6 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() - -def fill_tokens(tokens, patch_size, fill_id): - batch_size, seq_len = tokens.shape - if seq_len % patch_size == 0: - return tokens - else: - remaining = patch_size - seq_len % patch_size - final_padding = tokens.new(batch_size, remaining).fill_(fill_id) - return torch.cat((tokens, final_padding), dim=1) - - def rolling_polynomial_hash(t, hash_func_nb: int = 0): primes = [ 1000000007, @@ -569,103 +540,35 @@ def patch_mask(b, h, q_idx, kv_idx): ) # [bs, 1, q_len, kv_len] -def get_blt_input( - tokens: torch.Tensor, - enforce_patch_size_multiple: bool, - nb_boe: torch.Tensor, - patch_size: int, - boe_id: int, -): - """ - This function returns X_et, X_gt and X_dt, the encoder, global, and decoder - tokens respectively. - - Consider the input and target sequences: - X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] - Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] - with patch_size=4 - - Note 1: that there will be no special tokens introduced at the patch level. - Note 2: X_e needs to be trimmed to be passed to Global - - Current without boe: - X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] - X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch - X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] - Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] - - --> lag fix: - X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] - X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] - X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] - Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] - - Dynamic (current): - X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] - Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] - - entropy patching: - input: 7, bos, 9, 10 - pred (high entropy): eos, 8, 10, eos - - X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] - X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] - X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] - Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] - - --> lag fix no boe (force single byte first patch): - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch - X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] - Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] - - input: 4, 7, bos, 9, 10 - pred (high entropy): 5, eos, 8, 10, eos - - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch - X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] - Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] - - Handle the last byte properly. - patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] - X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] - X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch - X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] - Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] - - - bpe delim - X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] - X_g = [[3], [4,5,6,7,], [eos,bos,], .. - X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. - Y = [4,5,6,7,, eos,bos, 8,9,, .. - - - Note 1: that there will be no special tokens introduced at the patch level. - Note 2: X_e needs to be trimmed to be passed to Global - """ - batch_size, seq_len = tokens.shape - local_encoder_tokens = tokens - local_decoder_tokens = tokens +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: + if max_patch_length is None: + return patch_lengths - if nb_boe > 0: - padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) - local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) - # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + batch_size = patch_lengths.size(0) + split_all = [] + max_len = 0 - # create global tokens, contains boe tokens and eos - # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) - # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) - # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] - # global_tokens += global_tokens.eq(0).int() * boe_id - # TODO: fix this when we want to use block causal in the global. + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + # Split long patches into max_patch_length chunks + full, rem = divmod(length.item(), max_patch_length) + splits.extend([max_patch_length] * full + ([rem] if rem else [])) + split_all.append(splits) + max_len = max(max_len, len(splits)) - if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: - local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # Pad sequences to the maximum length + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + for i, splits in enumerate(split_all): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - return local_encoder_tokens, None, local_decoder_tokens + # Trim trailing columns that are all zeros + last_non_zero = (padded != 0).flip(1).int().argmax(1).min() + if last_non_zero < padded.shape[1]: + padded = padded[:, :padded.shape[1] - last_non_zero] + return padded class LocalModelBase(nn.Module): def __init__(self, config: BLTConfig, component_type: str = "encoder"): @@ -705,7 +608,7 @@ def __init__(self, config: BLTConfig, component_type: str = "encoder"): self.cross_attn_k = getattr(config, "cross_attn_k", None) self.eos_id = config.eos_token_id - self.boe_id = BOE_ID + self.boe_id = config.boe_id # Initialize cross attention layers as None (will be set by subclasses if needed) self.cross_attn_layers = None @@ -1081,18 +984,12 @@ class GlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - # Store config for later use self.config = config self.dim = config.dim_global - self.init_base_std = config.init_base_std - self.attn_impl = config.attn_impl - self.attn_bias_type = config.attn_bias_type - self.init_std_factor = config.init_std_factor - self.max_seqlen = config.max_seqlen self.rope_embeddings = RotaryEmbedding( theta=config.rope_theta, - head_dim=config.head_dim or config.dim_global // config.n_heads_global, + head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, max_seqlen=config.max_seqlen, rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, ) @@ -1115,10 +1012,6 @@ def __init__(self, config): for _ in range(config.n_layers_global): self.layers.append(BLTTransformerLayer(layer_params)) - # GlobalTransformer specific attributes - self.dropout = config.dropout - self.dim_token_emb = config.global_dim_patch_emb - self.token_embedding_projection = None if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: self.token_embedding_projection = nn.Linear( @@ -1144,8 +1037,8 @@ def forward( if mask is not None else create_causal_mask( seqlen, - self.attn_impl, - self.attn_bias_type, + self.config.attn_impl, + self.config.attn_bias_type, tokens=tokens, eos_id=self.eos_id, ) @@ -1154,12 +1047,12 @@ def forward( if self.token_embedding_projection is not None and h.shape[-1] != self.dim: h = self.token_embedding_projection(h) - h = F.dropout(h, p=self.dropout, training=self.training) + h = F.dropout(h, p=self.config.dropout, training=self.training) - freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.config.attn_impl) return h, cache @@ -1304,72 +1197,6 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - - - def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """ - Convert patch lengths to patch IDs for each token position. - - For each token position in the sequence, determines which patch it belongs to. - - Args: - patch_lengths: [batch_size, num_patches] - length of each patch - seq_len: total sequence length - - Returns: - patch_ids: [batch_size, seq_len] - patch index for each token position - - Example: - patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 - seq_len = 10 - Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] - # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 - """ - batch_size, num_patches = patch_lengths.shape - - # Create patch start positions: [0, 3, 5, 9] for the example above - patch_starts = torch.cat( - [ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total - ], - dim=-1, - ) - - # For each token position, find which patch it belongs to - # by finding the rightmost patch start that's <= the position - token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] - - # Broadcasting: patch_starts[batch, patch] <= token_positions[position] - # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t - position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) - - # Count how many patch starts are <= each position, then subtract 1 to get patch index - patch_ids = position_ge_patch_start.sum(dim=-1) - 1 - - return patch_ids - - def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: - """ - Create decoder patch IDs by skipping the first encoder patch. - - The decoder starts after the first patch (which contains BOE tokens), - so we need to map decoder positions to the remaining patches. - - Args: - patch_lengths: [batch_size, num_patches] from encoder - nb_boe: number of beginning-of-example tokens in first patch - seq_len: decoder sequence length - - Returns: - decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices - """ - # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) - decoder_patch_lengths = patch_lengths[:, 1:] - - # Create patch IDs for the decoder sequence using the remaining patches - return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) - def forward( self, tokens: torch.Tensor, @@ -1380,15 +1207,7 @@ def forward( bs, N = tokens.shape # Batch size and sequence length - # Get megabyte inputs - nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) - local_encoder_tokens, _, local_decoder_tokens = get_blt_input( - tokens=tokens, - enforce_patch_size_multiple=False, - nb_boe=nb_boe, - patch_size=self.config.patch_size, - boe_id=BOE_ID, - ) + local_encoder_tokens, local_decoder_tokens = tokens, tokens # Patching if patch_lengths is None: @@ -1403,8 +1222,6 @@ def forward( patch_size=self.config.patch_size, include_next_token=True, threshold=self.config.patching_threshold, - threshold_add=self.config.patching_threshold_add, - monotonicity=self.config.monotonicity, max_patch_length=self.config.max_patch_length, patching_batch_size=self.config.patching_batch_size, device=self.config.patching_device, @@ -1417,34 +1234,16 @@ def forward( (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device ) - # Apply any processing to patch lengths - if self.config.max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [ - BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) - for pl in patch_lengths.tolist() - ] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor( - patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device - ) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] - else: - if nb_boe > 0: - patch_lengths[:, 0] += nb_boe + patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) - assert torch.min(patch_lengths) >= 0 + + #assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) - assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( - f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" - ) + # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( + # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + # ) cross_attn_mask_enc = None # Cross-attention encoder @@ -1486,7 +1285,7 @@ def forward( h = h_cross.view(bs, patch_lengths.shape[1], -1) # Global transformer - global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id) rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) eos_patch_ids = patch_ids[rows, cols] global_tokens[rows, eos_patch_ids] = self.config.eos_token_id @@ -1497,20 +1296,21 @@ def forward( ) # Unpatching - dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] - # Generate decoder patch IDs - decoder_patch_ids = self._decoder_patch_ids_from_lengths(patch_lengths, nb_boe, local_decoder_tokens.shape[-1]) - assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" - assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( - f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" - ) + dec_embeds = h_encoder + + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) + # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + # ) # Cross-attention decoder if not self.config.cross_attn_decoder: h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) cross_attn_mask_dec = None - assert local_decoder_tokens.shape == h.shape[:-1] + # assert local_decoder_tokens.shape == h.shape[:-1] else: cross_attn_mask_dec = cross_attn_mask( decoder_patch_ids, @@ -1530,7 +1330,50 @@ def forward( cross_mask=cross_attn_mask_dec, ) return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total + ], + dim=-1, + ) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): @@ -1542,25 +1385,6 @@ def __init__(self, config): max_seqlen=config.patcher_max_seqlen, rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = config.patcher_eos_token_id - - # Extract additional parameters for BLTTransformerLayer - n_kv_heads = ( - getattr(config, "patcher_n_kv_heads", None) - if hasattr(config, "patcher_dim") - else getattr(config, "n_kv_heads", None) - ) - multiple_of = ( - getattr(config, "patcher_multiple_of", 256) - if hasattr(config, "patcher_dim") - else getattr(config, "multiple_of", 256) - ) - ffn_dim_multiplier = ( - getattr(config, "patcher_ffn_dim_multiplier", None) - if hasattr(config, "patcher_dim") - else getattr(config, "ffn_dim_multiplier", None) - ) self.layers = nn.ModuleList() for _ in range(config.patcher_n_layers): @@ -1570,19 +1394,16 @@ def __init__(self, config): "dim": config.patcher_dim, "n_heads": config.patcher_n_heads, "head_dim": config.patcher_head_dim, - "n_kv_heads": n_kv_heads, + "n_kv_heads": config.patcher_n_kv_heads, "rope_theta": config.patcher_rope_theta, - "multiple_of": multiple_of, - "ffn_dim_multiplier": ffn_dim_multiplier, + "multiple_of": config.patcher_multiple_of, + "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, "norm_eps": config.patcher_norm_eps, } ) ) - # LMTransformer specific attributes - self.sliding_window = config.patcher_sliding_window - - assert config.patcher_vocab_size > 0 + #assert config.patcher_vocab_size > 0 self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) @@ -1597,26 +1418,18 @@ def __init__(self, config): def forward( self, token_values: torch.Tensor, - target: Optional[torch.Tensor] = None, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - attn_impl: str | None = None, patch_size: Optional[int] = None, include_next_token: bool = True, threshold: Optional[float] = None, - threshold_add: Optional[float] = None, - monotonicity: bool = False, max_patch_length: Optional[int] = None, patching_batch_size: int = 1, device: Optional[str] = None, - enable_grad: bool = False, ): - attn_impl = self.config.patcher_attn_impl if attn_impl is None else attn_impl # Handle chunked processing for entropy calculation entropies = [] preds = [] - max_length = min(getattr(self, "max_length", 8192), self.config.patcher_max_seqlen) + max_length = self.config.patcher_max_seqlen batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) @@ -1633,16 +1446,16 @@ def forward( h = self.tok_embeddings(split) chunk_mask = create_causal_mask( seqlen, - attn_impl, + self.config.patcher_attn_impl , self.config.patcher_attn_bias_type, - sliding_window=self.sliding_window, + sliding_window=self.config.patcher_sliding_window, tokens=split, - eos_id=self.eos_id, + eos_id=self.config.eos_id, ) freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl ) pred = self.output(self.norm(h)) pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] @@ -1650,10 +1463,8 @@ def forward( pred_entropies = self.entropy(pred) entropies.append(pred_entropies) - concat_entropies = torch.cat(entropies, dim=0) - concat_entropies = concat_entropies.reshape(token_values.shape) - concat_preds = torch.cat(preds, dim=0) - concat_preds = concat_preds.reshape(token_values.shape[0], -1) + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) # Always compute patch lengths from concatenated entropies bs, seq_len = token_values.shape @@ -1665,34 +1476,17 @@ def forward( concat_entropies, patch_size, include_next_token=include_next_token, - threshold=threshold, - threshold_add=threshold_add, - monotonicity=monotonicity, + threshold=threshold ) patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) - # Apply any processing to patch lengths - if max_patch_length is not None: - # TODO: avoid going back to a list here. - patch_lengths = [self.split_large_numbers(pl, max_patch_length) for pl in patch_lengths.tolist()] - max_len = max([len(pl) for pl in patch_lengths]) - patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] - patch_lengths = torch.tensor(patch_lengths, dtype=token_values.dtype, device=token_values.device) - assert not check_non_zero_after_zero(patch_lengths) - # Find the last non-zero column index using argmax on a reversed version of the tensor - last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() - # Slice the tensor up to the last non-zero column - patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] - + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) return concat_entropies, patch_lengths, concat_preds - - - @staticmethod def entropy(scores): """ @@ -1755,8 +1549,6 @@ def find_entropy_patch_start_ids( entropies, patch_size=None, threshold=None, - threshold_add=None, - monotonicity=False, include_next_token=True, ): """ @@ -1786,21 +1578,6 @@ def find_entropy_patch_start_ids( patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) return patch_start_ids - @staticmethod - def split_large_numbers(lst, m): - new_lst = [] - for i in lst: - if i > m: - while i > m: - new_lst.append(m) - i -= m - new_lst.append(i) - else: - new_lst.append(i) - assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" - return new_lst - - def init_hash_embeddings( config, local_encoder_dim: int, @@ -1833,4 +1610,4 @@ def init_hash_embeddings( "LocalEncoder", "LocalDecoder", "GlobalTransformer", -] \ No newline at end of file +] From f552e2759c55c61872e9ae7eb1ca46ffd1554f1e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 16 Jun 2025 12:52:01 +0000 Subject: [PATCH 018/139] rename classes --- .../models/blt_wip/modeling_blt.py | 46 ++++++------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 831791eabeaf..94069d575410 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -299,8 +299,6 @@ def forward( return output - - class BLTMLP(nn.Module): def __init__( self, @@ -570,14 +568,12 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> return padded -class LocalModelBase(nn.Module): +class BLTLocalModelBase(nn.Module): def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() - # Store config for later use self.config = config - # Use component-specific dimensions if component_type == "encoder": self.dim = config.dim_local_encoder self.n_layers = config.n_layers_local_encoder @@ -683,9 +679,7 @@ def apply_embedding(self, tokens, embeds): return self.tok_embeddings(tokens) - - -class LocalEncoder(LocalModelBase): +class BLTLocalEncoder(BLTLocalModelBase): def __init__(self, config: BLTConfig): super().__init__(config, component_type="encoder") @@ -798,7 +792,7 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): return reduced_embs -class LocalDecoder(LocalModelBase): +class BLTLocalDecoder(BLTLocalModelBase): def __init__(self, config: BLTConfig): super().__init__(config, component_type="decoder") @@ -886,11 +880,6 @@ def forward( class BLTCrossAttention(nn.Module): - """ - BLTCrossAttention block to attend to the encoder states from the decoder. - Rope is not supported. - """ - def __init__( self, dim: int, @@ -978,9 +967,7 @@ def forward( return x + output - - -class GlobalTransformer(nn.Module): +class BLTGlobalTransformer(nn.Module): def __init__(self, config): super().__init__() @@ -1057,8 +1044,6 @@ def forward( return h, cache - - def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, @@ -1107,7 +1092,7 @@ class BLTPreTrainedModel(PreTrainedModel): config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = False # BLT uses its own attention implementation _supports_sdpa = True @@ -1157,14 +1142,14 @@ def _init_weights(self, module): for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - elif isinstance(module, (LocalEncoder, LocalDecoder)): + elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): if module.token_embedding_projection is not None: module.token_embedding_projection._custom_std = module.dim ** (-0.5) if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) - elif isinstance(module, GlobalTransformer): + elif isinstance(module, BLTGlobalTransformer): if module.token_embedding_projection is not None: module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) @@ -1179,9 +1164,9 @@ def __init__(self, config: BLTConfig): super().__init__(config) self.config = config - self.local_encoder = LocalEncoder(config) - self.global_transformer = GlobalTransformer(config) - self.local_decoder = LocalDecoder(config) + self.local_encoder = BLTLocalEncoder(config) + self.global_transformer = BLTGlobalTransformer(config) + self.local_decoder = BLTLocalDecoder(config) self.encoder_hash_tok_embedding = init_hash_embeddings( config, @@ -1236,9 +1221,7 @@ def forward( patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) - #assert torch.min(patch_lengths) >= 0 - # Generate patch IDs from patch_lengths patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( @@ -1334,7 +1317,6 @@ def forward( def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """ Convert patch lengths to patch IDs for each token position. - For each token position in the sequence, determines which patch it belongs to. Args: @@ -1607,7 +1589,7 @@ def init_hash_embeddings( "BLTPreTrainedModel", "BLTModel", "BLTPatcher", - "LocalEncoder", - "LocalDecoder", - "GlobalTransformer", -] + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", +] \ No newline at end of file From 2b9dd64fd266e334793c5a243305e865a28c9679 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 18 Jun 2025 16:57:01 +0000 Subject: [PATCH 019/139] adding transformers Attention class and RotaryEmbedding class --- src/demo_hf.py | 5 +- .../models/blt_wip/configuration_blt.py | 13 + .../models/blt_wip/modeling_blt.py | 899 +++++++++--------- 3 files changed, 481 insertions(+), 436 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 65715861f45f..c935add72575 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -3,7 +3,7 @@ import torch -from transformers.models.blt_wip.modeling_blt import BLTModel +from transformers.models.blt_wip.modeling_blt_modellike import BLTModel from transformers.models.blt_wip.tokenization_blt import BLTTokenizer @@ -11,6 +11,9 @@ os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" +import gc +gc.collect() +torch.cuda.empty_cache() def get_generation_range(prompt_tokens: list[list[int]] | None, max_gen_len: int) -> tuple[int, int]: batch_min_prompt_length = min([len(t) for t in prompt_tokens]) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 7c645e8b18f6..bb05125ce248 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -514,6 +514,19 @@ def __init__( int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 ] + # Rope + self.rope_scaling={ + "type": "dynamic", + "factor": 2.0, + "rope_type": "dynamic" + } + + self.num_key_value_heads=n_heads_local_encoder + self.max_position_embeddings=max_seqlen + self.hidden_size=dim_local_encoder + self.num_attention_heads=n_heads_local_encoder + # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 94069d575410..ba237d92eb05 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -3,6 +3,7 @@ import logging import os from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn @@ -10,7 +11,9 @@ from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from .configuration_blt import ( BLTConfig, PatchingModeEnum, @@ -149,156 +152,6 @@ def apply_rotary_emb( return xq_out.type_as(xq), xk_out.type_as(xk) -# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. -class RotaryEmbedding(torch.nn.Module): - """ - RotaryEmbedding Module - """ - - def __init__( - self, - theta: float, - head_dim: int, - max_seqlen: int = 1024, - rope_use_fp32_in_outer_product: bool = False, - ): - super().__init__() - - self.theta = theta - self.head_dim = head_dim - self.max_seqlen = max_seqlen - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - dim=head_dim, - end=max_seqlen, - theta=theta, - rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ), - persistent=False, - ) - - - def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): - """ - Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions - Args: - seqlen (int): Contiguous sequence length - tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen - - Returns: - Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis - """ - test = (seqlen is not None) or (tok_idx is not None) - assert test, "Should provide atleast seqlen or tok_idx" - if tok_idx is not None: - return self.freqs_cis[tok_idx] - elif seqlen is not None: - return self.freqs_cis[0:seqlen] - - -class BLTAttention(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - rope_theta: float, - ): - super().__init__() - - self.dim = dim - self.head_dim = head_dim - self.rope_theta = rope_theta - - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads - - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, - ) - - def forward( - self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", - ) -> torch.Tensor: - # B S D - bsz, seq_len, dim = x.shape - xq = self.wq(x.view_as(x)) - xk = self.wk(x.view_as(x)) - xv = self.wv(x.view_as(x)) - - output_shape = xq.shape - # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) - - # This condition helps us be easily compatible - # with inference by adding a pluggable KVCache - if hasattr(self, "kv_cache"): - xk, xv = self.kv_cache.update(xk, xv, tok_idx) - - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) - - if attn_impl == "flex_attention": - assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - output = flex_attention_comp(xq, xk, xv, block_mask=mask) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - - elif attn_impl == "sdpa": - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - assert mask is None or isinstance(mask, (str, torch.Tensor)) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None - output = F.scaled_dot_product_attention( - xq, - xk, - xv, - is_causal=is_causal, - attn_mask=mask, - ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - else: - raise NotImplementedError(f"Attention implementation {attn_impl} not supported") - - output_reshaped = output.reshape(output_shape) - - output = self.wo(output_reshaped) - - return output - - class BLTMLP(nn.Module): def __init__( self, @@ -343,37 +196,177 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + + +class BLTSelfAttention(nn.Module): + def __init__(self, config: BLTConfig, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim ** -0.5 + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.wq = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.wq(hidden_states) + key_states = self.wk(hidden_states) + value_states = self.wv(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + self.config._attn_implementation = "sdpa" + self.scaling = None + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value class BLTTransformerLayer(nn.Module): - def __init__(self, args): + def __init__(self, dim, n_heads, config, layer_idx=0): super().__init__() # Extract parameters from dictionary - dim = args["dim"] - n_heads = args["n_heads"] - head_dim = args["head_dim"] - n_kv_heads = args["n_kv_heads"] - rope_theta = args["rope_theta"] - multiple_of = args["multiple_of"] - ffn_dim_multiplier = args["ffn_dim_multiplier"] - norm_eps = args["norm_eps"] - - assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" + dim = dim + n_heads = n_heads + head_dim = getattr(config, "head_dim", None) + n_kv_heads = getattr(config, "n_kv_heads", None) + rope_theta = getattr(config, "rope_theta", None) + multiple_of = getattr(config, "multiple_of", 256) + ffn_dim_multiplier = getattr(config, "ffn_dim_multiplier", None) + norm_eps = getattr(config, "norm_eps", None) + self.head_dim = head_dim or dim // n_heads self.n_heads = n_heads or dim // head_dim self.n_kv_heads = n_kv_heads or self.n_heads - assert n_heads % self.n_kv_heads == 0 - assert dim % n_heads == 0 + config.hidden_size = dim + + self.attention = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.attention = BLTAttention( - dim=dim, - head_dim=self.head_dim, - n_heads=self.n_heads, - n_kv_heads=self.n_kv_heads, - rope_theta=rope_theta, - ) self.feed_forward = BLTMLP( dim=dim, hidden_dim=4 * dim, @@ -385,21 +378,33 @@ def __init__(self, args): def forward( self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", + hidden_states: torch.Tensor, + past_key_value: Optional[bool] = None, + position_embeddings: Optional[torch.Tensor] = None, + + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: - norm_x = self.attention_norm(x) - attn_out = self.attention( - norm_x, - freq_cis, - tok_idx=tok_idx, - mask=mask, - attn_impl=attn_impl, + + residual = hidden_states + norm_hidden_states = self.attention_norm(hidden_states) + + + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=norm_hidden_states, + # TODO: = BLT, attn_out = self.attention(self.attention_norm(x), in TransformerBlock.forward, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings ) - h = x + attn_out + + h = residual + hidden_states h_norm = self.ffn_norm(h) out = h + self.feed_forward(h_norm) return out @@ -568,156 +573,109 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> return padded -class BLTLocalModelBase(nn.Module): - def __init__(self, config: BLTConfig, component_type: str = "encoder"): + +def create_causal_mask_for_blt( + seqlen: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + """ + Creates a causal mask for BLT local encoder. + """ + min_value = torch.finfo(dtype).min + mask = torch.full( + (batch_size, 1, seqlen, seqlen), # Note: using seqlen, not total_seqlen + min_value, + dtype=dtype, + device=device, + ) + + if sliding_window is not None: + # Create local causal mask with sliding window + for i in range(seqlen): + start_idx = max(0, i - sliding_window + 1) + mask[:, :, i, start_idx:i + 1] = 0 + else: + # Create full causal mask + mask = torch.triu(mask, diagonal=0) + mask = mask.masked_fill(mask == 0, min_value) + + return mask + + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config: BLTConfig, device=None): super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - if component_type == "encoder": - self.dim = config.dim_local_encoder - self.n_layers = config.n_layers_local_encoder - self.n_heads = config.n_heads_local_encoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - elif component_type == "decoder": - self.dim = config.dim_local_decoder - self.n_layers = config.n_layers_local_decoder - self.n_heads = config.n_heads_local_decoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - else: - raise ValueError(f"Unknown component_type: {component_type}") + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() - self.dropout = config.dropout - self.vocab_size = config.vocab_size + config.pm_size - self.patch_size = config.patch_size + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling - self.attn_impl = config.attn_impl - self.use_rope = config.use_rope - self.init_std_factor = config.init_std_factor - self.init_base_std = config.init_base_std - self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) - self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) - self.cross_attn_k = getattr(config, "cross_attn_k", None) - self.eos_id = config.eos_token_id + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - self.boe_id = config.boe_id - # Initialize cross attention layers as None (will be set by subclasses if needed) - self.cross_attn_layers = None - # Create parameter dict for BLTTransformerLayers - layer_params = { - "dim": self.dim, - "n_heads": self.n_heads, - "head_dim": config.head_dim, - "n_kv_heads": getattr(config, "n_kv_heads", None), - "rope_theta": config.rope_theta, - "multiple_of": getattr(config, "multiple_of", 256), - "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), - "norm_eps": config.norm_eps, - } - - self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) - - if not self.use_rope: - self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length - else: - self.rope = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or self.dim // self.n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - self.pos_embeddings = None +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTConfig): + super().__init__() + self.config = config #TODO: rm this - # Set dimension-specific embedding dimensions - if component_type == "encoder": - self.dim_token_emb = config.encoder_dim_token_emb - self.dim_patch_emb = config.encoder_dim_patch_emb - elif component_type == "decoder": - self.dim_token_emb = config.decoder_dim_token_emb - self.dim_patch_emb = config.dim_global + self.dropout = config.dropout + + self.layers = nn.ModuleList([BLTTransformerLayer(config.dim_local_encoder, config.n_heads_local_encoder, config) for _ in range(config.n_layers_local_encoder)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + self.pos_embeddings = None self.token_embedding_projection = ( - nn.Linear(self.dim_token_emb, self.dim, bias=False) - if self.dim_token_emb is not None and self.dim_token_emb != self.dim + nn.Linear(config.encoder_dim_token_emb, config.dim_local_encoder, bias=False) + if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != config.dim_local_encoder else None ) self.patch_embedding_projection = self._create_patch_projection(config) - def _should_create_patch_projection(self, config: BLTConfig): - dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim - - # Check cross attention conditions - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - return dimension_mismatch or cross_attn_conditions + self.tok_embeddings = nn.Embedding(config.vocab_size + config.pm_size, config.dim_local_encoder) - def _create_patch_projection(self, config): - if not self._should_create_patch_projection(config): - return None - - output_dim = self.dim_token_emb * (self.cross_attn_k or 1) - - return nn.Linear( - in_features=self.dim_patch_emb, - out_features=output_dim, - bias=False, - ) - - def apply_embedding(self, tokens, embeds): - if embeds is not None: - return embeds - else: - return self.tok_embeddings(tokens) - - -class BLTLocalEncoder(BLTLocalModelBase): - def __init__(self, config: BLTConfig): - super().__init__(config, component_type="encoder") - - self.apply_transformer = config.use_local_encoder_transformer - self.downsampling_by_pooling = config.downsampling_by_pooling - self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_nheads = config.cross_attn_nheads - - self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) - - if self.cross_attn_encoder: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) + # Initialize cross attention layers as None (will be set if needed) + self.cross_attn_layers = None + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.n_layers_local_encoder if config.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=config.dim_local_encoder, + head_dim=config.dim_local_encoder // config.cross_attn_nheads, + n_heads=config.cross_attn_nheads, + n_kv_heads=config.cross_attn_nheads, + norm_eps=config.norm_eps, ) - - def apply_embedding(self, tokens, embeds): - if embeds is not None: - assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." - return embeds - else: - return self.tok_embeddings(tokens) + ) def forward( self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, @@ -726,34 +684,52 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - bs, seqlen = tokens.shape + bs, seqlen = input_ids.shape if mask is None: mask = create_causal_mask( seqlen, - self.attn_impl, + self.config.attn_impl, "local_block_causal", - sliding_window=self.sliding_window, - tokens=tokens, - eos_id=self.eos_id, + sliding_window=self.config.local_attention_window_len, + tokens=input_ids, + eos_id=self.config.eos_token_id, ) - h = self.apply_embedding(tokens, embeds) - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) - h = F.dropout(h, p=self.dropout, training=self.training) + batch_size, seq_length, _ = input_embeds.shape - for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) - # check if cross attention should be applied to either all layer or only the last layer - if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): - # apply pooling and project - if self.cross_attn_init_by_pooling and patch_embeds is None: + if mask is None: + attention_mask = create_causal_mask_for_blt( + seqlen=seq_length, + batch_size=batch_size, + device=input_embeds.device, + dtype=input_embeds.dtype, + sliding_window=self.config.sliding_window, + ) + + h = input_embeds + + h_residual = input_embeds + h = nn.functional.dropout(h, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(h, position_ids) + + h = F.dropout(h, p=self.config.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + h = layer(h, position_embeddings=position_embeddings, attention_mask=None) + + if getattr(self.config, "cross_attn_encoder", None) and (idx == len(self.layers) - 1 or self.config.cross_attn_all_layers_encoder): + if self.config.cross_attn_init_by_pooling and patch_embeds is None: patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * getattr(self.config, "cross_attn_k", 1), self.config.dim_local_encoder) - layer_idx = i if self.cross_attn_all_layers_encoder else 0 + layer_idx = idx if self.config.cross_attn_all_layers_encoder else 0 patch_embeds_cross = self.cross_attn_layers[layer_idx]( x=patch_embeds, kv=h, @@ -761,8 +737,33 @@ def forward( ) patch_embeds = patch_embeds + patch_embeds_cross - h_residual = patch_embeds if self.cross_attn_encoder else None + h_residual = patch_embeds if getattr(self.config, "cross_attn_encoder", None) else None return (h, h_residual), cache + + def _create_patch_projection(self, config): + dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder + + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.encoder_dim_token_emb * (getattr(config, "cross_attn_k", None) or 1) + + return nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def embed_tokens(self, tokens, embeds): + if embeds is not None: + assert self.config.encoder_hash_byte_group_size is not None, "Not expecting embeddings to be passed." + return embeds + else: + return self.tok_embeddings(tokens) def patch_reduce(self, h, max_num_patches, reduction, patch_ids): """ @@ -792,38 +793,70 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): return reduced_embs -class BLTLocalDecoder(BLTLocalModelBase): +class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTConfig): - super().__init__(config, component_type="decoder") - - # Model configuration flags - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_nheads = config.cross_attn_nheads - - self.norm = RMSNorm(self.dim, eps=config.norm_eps) - - if self.cross_attn_decoder: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) + super().__init__() + + self.config = config + + self.layers = nn.ModuleList([BLTTransformerLayer(config.dim_local_decoder, config.n_heads_local_decoder, config) for _ in range(config.n_layers_local_decoder)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.pos_embeddings = None + + self.token_embedding_projection = ( + nn.Linear(config.decoder_dim_token_emb, config.dim_local_decoder, bias=False) + if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != config.dim_local_decoder + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + self.norm = RMSNorm(config.dim_local_decoder, eps=config.norm_eps) + + # Initialize cross attention layers as None (will be set if needed) + self.cross_attn_layers = None + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.n_layers_local_decoder if config.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=config.dim_local_decoder, + head_dim=config.dim_local_decoder // config.cross_attn_nheads, + n_heads=config.cross_attn_nheads, + n_kv_heads=config.cross_attn_nheads, + norm_eps=config.norm_eps, ) + ) - self.output = nn.Linear( - self.dim, - config.vocab_size, + self.output = nn.Linear(config.dim_local_decoder, config.vocab_size, bias=False) + + def _create_patch_projection(self, config): + dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder + + # Check cross attention conditions + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.decoder_dim_token_emb * (getattr(config, "cross_attn_k", None) or 1) + + return nn.Linear( + in_features=config.dim_global, + out_features=output_dim, bias=False, ) + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + def forward( self, tokens: torch.Tensor, @@ -839,29 +872,33 @@ def forward( if mask is None: mask = create_causal_mask( seqlen, - self.attn_impl, + self.config.attn_impl, "local_block_causal", - sliding_window=self.sliding_window, + sliding_window=self.config.local_attention_window_len, tokens=tokens, - eos_id=self.eos_id, + eos_id=self.config.eos_token_id, ) h = embeds + batch_size, seq_length, _ = embeds.shape + + if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) - if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + if getattr(self.config, "cross_attn_k", None) is not None: + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.dim_local_decoder) - if patch_embeds is not None and not self.cross_attn_decoder: + if patch_embeds is not None and not getattr(self.config, "cross_attn_decoder", None): h = h + patch_embeds - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(h, position_ids) - h = F.dropout(h, p=self.dropout, training=self.training) + h = F.dropout(h, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): - if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + if getattr(self.config, "cross_attn_decoder", None) and (i == 0 or self.config.cross_attn_all_layers_decoder): # Use cross attention to extract info from patch_embeds into h h_cross = self.cross_attn_layers[i]( x=h, @@ -870,10 +907,11 @@ def forward( ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + h = layer(h, position_embeddings=position_embeddings, attention_mask=None) + h_preds = self.norm(h) - h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = F.dropout(h_preds, p=self.config.dropout, training=self.training) h_preds = self.output(h_preds) h_preds = h_preds.float() return h_preds, cache @@ -973,34 +1011,17 @@ def __init__(self, config): self.config = config - self.dim = config.dim_global - self.rope_embeddings = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, - max_seqlen=config.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) - - # Create parameter dict for BLTTransformerLayers - layer_params = { - "dim": self.dim, - "n_heads": config.n_heads_global, - "head_dim": config.head_dim, - "n_kv_heads": getattr(config, "n_kv_heads_global", None), - "rope_theta": config.rope_theta, - "multiple_of": getattr(config, "multiple_of", 256), - "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), - "norm_eps": config.norm_eps, - } - self.layers = nn.ModuleList() + old = config.n_kv_heads + config.n_kv_heads = config.n_kv_heads_global for _ in range(config.n_layers_global): - self.layers.append(BLTTransformerLayer(layer_params)) + self.layers.append(BLTTransformerLayer(self.config.dim_global, self.config.n_heads_global, config)) + config.n_kv_heads = old + + self.rotary_emb = BLTRotaryEmbedding(config=config) self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != config.dim_global: self.token_embedding_projection = nn.Linear( config.global_dim_patch_emb, config.dim_global, @@ -1027,11 +1048,11 @@ def forward( self.config.attn_impl, self.config.attn_bias_type, tokens=tokens, - eos_id=self.eos_id, + eos_id=self.config.eos_id, ) ) - if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + if self.token_embedding_projection is not None and h.shape[-1] != self.config.dim_global: h = self.token_embedding_projection(h) h = F.dropout(h, p=self.config.dropout, training=self.training) @@ -1123,10 +1144,10 @@ def _init_weights(self, module): b=3 * std, ) - elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): - nn.init.ones_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) + # elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): + # nn.init.ones_(module.weight) + # if module.bias is not None: + # nn.init.zeros_(module.bias) elif isinstance(module, RotaryEmbedding): module.freqs_cis[...] = precompute_freqs_cis( @@ -1138,16 +1159,23 @@ def _init_weights(self, module): elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.local_encoder.dim ** (-0.5) + emb_std = module.config.dim_local_encoder ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): + elif isinstance(module, BLTLocalEncoder): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.dim ** (-0.5) + module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) elif isinstance(module, BLTGlobalTransformer): if module.token_embedding_projection is not None: @@ -1170,7 +1198,7 @@ def __init__(self, config: BLTConfig): self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=self.local_encoder.dim, + local_encoder_dim=config.dim_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) @@ -1256,8 +1284,8 @@ def forward( # Local encoder (h_encoder, h_cross), cache_encoder = self.local_encoder( - tokens=local_encoder_tokens, - embeds=local_encoder_embeds, + input_ids=local_encoder_tokens, + input_embeds=local_encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], @@ -1359,31 +1387,32 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): + config.num_attention_heads = config.patcher_n_heads + config.hidden_size = config.patcher_dim + + + config.n_heads = config.patcher_n_heads + + + config.num_key_value_heads = 12 # config.num_key_value_heads #TODO: add patcher_n_kv_heads + + config.head_dim = 64 #self.config.patcher_head_dim #TODO: add super().__init__(config) - self.rope_embeddings = RotaryEmbedding( - theta=config.patcher_rope_theta, - head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, - max_seqlen=config.patcher_max_seqlen, - rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, - ) + self.config.hidden_size = self.config.patcher_dim + self.config.n_heads = self.config.patcher_n_heads + self.config.head_dim = 64 #self.config.patcher_head_dim #TODO: add + + # Create a patcher-specific config copy to use patcher_rope_theta and simple rope + import copy + patcher_config = copy.deepcopy(config) + patcher_config.rope_theta = config.patcher_rope_theta + patcher_config.rope_scaling = {"rope_type": "default"} # Use simple default rope for patcher + self.rotary_emb = BLTRotaryEmbedding(config=patcher_config) self.layers = nn.ModuleList() for _ in range(config.patcher_n_layers): - self.layers.append( - BLTTransformerLayer( - { - "dim": config.patcher_dim, - "n_heads": config.patcher_n_heads, - "head_dim": config.patcher_head_dim, - "n_kv_heads": config.patcher_n_kv_heads, - "rope_theta": config.patcher_rope_theta, - "multiple_of": config.patcher_multiple_of, - "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, - "norm_eps": config.patcher_norm_eps, - } - ) - ) + self.layers.append(BLTTransformerLayer(config.patcher_dim, config.patcher_n_heads, config)) #assert config.patcher_vocab_size > 0 @@ -1425,21 +1454,21 @@ def forward( # Process chunk: embeddings -> layers -> output bsz, seqlen = split.shape - h = self.tok_embeddings(split) - chunk_mask = create_causal_mask( - seqlen, - self.config.patcher_attn_impl , - self.config.patcher_attn_bias_type, - sliding_window=self.config.patcher_sliding_window, - tokens=split, - eos_id=self.config.eos_id, - ) - freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + input_embeds = self.tok_embeddings(split) + + hidden_states = input_embeds + + batch_size, seq_length, _ = input_embeds.shape + + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl ) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) - pred = self.output(self.norm(h)) + pred = self.output(self.norm(hidden_states)) pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] preds.append(pred) pred_entropies = self.entropy(pred) From f25a99b56c1b0cd7906f2c87b11cc13ef1b5fd04 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 19 Jun 2025 09:11:36 +0000 Subject: [PATCH 020/139] exchanged blt modules for transformers modules: attention, rotary_emb, create_causal_mask, etc --- .../models/blt_wip/modeling_blt.py | 170 ++---------------- 1 file changed, 17 insertions(+), 153 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index ba237d92eb05..bb388a5e6fd1 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -25,38 +25,6 @@ flex_attention_comp = flex_attention - -def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - -def create_causal_mask( - seqlen, - attn_impl: str, - attn_bias_type: str | None, - *, - eos_id: int | None = None, - tokens: torch.Tensor | None = None, - sliding_window: int | None = None, -): - if attn_impl == "sdpa": - BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) - - if attn_bias_type == "causal": - return "causal" - - if BLT_SUPPRESS_ATTN_ERROR == 1: - return "causal" - else: - raise ValueError( - "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" - ) - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") - - def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), @@ -77,81 +45,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) - -def precompute_freqs_cis( - dim: int, - end: int, - theta: float = 10000.0, - rope_use_fp32_in_outer_product: bool = False, -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - if rope_use_fp32_in_outer_product: - t = t.to(torch.float32) - - freqs = torch.outer(t, freqs).float() - - cos, sin = freqs.cos(), freqs.sin() - - return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - seq_dim (int): Sequence dimension index. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert 0 <= seq_dim < ndim - assert freqs_cis.shape == ( - x.shape[seq_dim], - x.shape[-3], - 2, - 2, - ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" - shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - seq_dim: int, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 - xq_out = (xq_ * freqs_cis).sum(5).flatten(3) - xk_out = (xk_ * freqs_cis).sum(5).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - class BLTMLP(nn.Module): def __init__( self, @@ -685,16 +578,6 @@ def forward( ): """ """ bs, seqlen = input_ids.shape - if mask is None: - mask = create_causal_mask( - seqlen, - self.config.attn_impl, - "local_block_causal", - sliding_window=self.config.local_attention_window_len, - tokens=input_ids, - eos_id=self.config.eos_token_id, - ) - if input_embeds is None: input_embeds = self.embed_tokens(input_ids) @@ -870,15 +753,15 @@ def forward( assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask( - seqlen, - self.config.attn_impl, - "local_block_causal", - sliding_window=self.config.local_attention_window_len, - tokens=tokens, - eos_id=self.config.eos_token_id, + attention_mask = create_causal_mask_for_blt( + seqlen=seq_length, + batch_size=batch_size, + device=embeds.device, + dtype=embeds.dtype, + sliding_window=self.config.sliding_window, ) + h = embeds batch_size, seq_length, _ = embeds.shape @@ -1030,37 +913,26 @@ def __init__(self, config): def forward( self, - tokens: torch.Tensor, + input_ids: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, - embeds: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - bs, seqlen = tokens.shape - - h = embeds + batch_size, seq_length, _ = input_embeds.shape - mask = ( - mask - if mask is not None - else create_causal_mask( - seqlen, - self.config.attn_impl, - self.config.attn_bias_type, - tokens=tokens, - eos_id=self.config.eos_id, - ) - ) + h = input_embeds if self.token_embedding_projection is not None and h.shape[-1] != self.config.dim_global: h = self.token_embedding_projection(h) h = F.dropout(h, p=self.config.dropout, training=self.training) - freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(h, position_ids) for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.config.attn_impl) + h = layer(h, position_embeddings=position_embeddings, attention_mask=None) return h, cache @@ -1148,15 +1020,7 @@ def _init_weights(self, module): # nn.init.ones_(module.weight) # if module.bias is not None: # nn.init.zeros_(module.bias) - - elif isinstance(module, RotaryEmbedding): - module.freqs_cis[...] = precompute_freqs_cis( - dim=module.head_dim, - end=module.max_seqlen, - theta=module.theta, - rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, - ) - + elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: emb_std = module.config.dim_local_encoder ** (-0.5) @@ -1302,8 +1166,8 @@ def forward( global_tokens[rows, eos_patch_ids] = self.config.eos_token_id h, _ = self.global_transformer( - embeds=h, - tokens=global_tokens, + input_embeds=h, + input_ids=global_tokens, ) # Unpatching From 73f7e169ccc0e4d6a107ecdfe2e38f0a3cc84db3 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 19 Jun 2025 16:01:00 +0000 Subject: [PATCH 021/139] seperate out patcher config, update modeling and conversion script --- src/convert_blt_to_hf.py | 74 +++--- .../models/blt_wip/configuration_blt.py | 239 +++++++++++------- .../models/blt_wip/modeling_blt.py | 115 ++++----- 3 files changed, 235 insertions(+), 193 deletions(-) diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index cb933961b4e2..5a4d368b6b7e 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file from transformers.models.blt_wip.configuration_blt import BLTConfig -from transformers.models.blt_wip.modeling_blt_wip import BLTModel +from transformers.models.blt_wip.modeling_blt_modellike import BLTModel from transformers.utils import logging as transformers_logging @@ -61,6 +61,40 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str if isinstance(patch_size, float): patch_size = int(patch_size) + # Create patcher configuration dictionary + patcher_config = { + "vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "dim": int(entropy_model_params.get("dim", 512)), + "n_layers": int(entropy_model_params.get("n_layers", 8)), + "n_heads": int(entropy_model_params.get("n_heads", 8)), + "head_dim": int(entropy_model_params.get("head_dim")) + if entropy_model_params.get("head_dim") is not None + else None, # Let BLTPatcherConfig compute this from dim // n_heads + "n_kv_heads": int(entropy_model_params.get("n_kv_heads")) + if entropy_model_params.get("n_kv_heads") is not None + else None, # Let BLTPatcherConfig default this to n_heads + "max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "dropout": entropy_model_params.get("dropout", 0.0), + "sliding_window": int(entropy_model_params.get("sliding_window", 512)) + if entropy_model_params.get("sliding_window") is not None + else None, + "ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), + "multiple_of": int(entropy_model_params.get("multiple_of", 256)), + "rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "rope_use_fp32_in_outer_product": entropy_model_params.get( + "rope_use_fp32_in_outer_product", False + ), + "attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "init_base_std": entropy_model_params.get("init_base_std"), + "init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), + "dim_token_emb": entropy_model_params.get("dim_token_emb"), + "weight_tying": entropy_model_params.get("weight_tying", False), + "bos_token_id": entropy_model_params.get("bos_token_id", 1), + "eos_token_id": entropy_model_params.get("eos_token_id", 2), + } + unified_config.update( { "patch_in_forward": True, @@ -73,36 +107,7 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str "patching_batch_size": patcher_args.get("patching_batch_size", 1), "patching_device": patcher_args.get("patching_device", "cuda"), "monotonicity": patcher_args.get("monotonicity", False), - "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), - "patcher_dim": int(entropy_model_params.get("dim", 512)), - "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), - "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), - "patcher_head_dim": int(entropy_model_params.get("head_dim")) - if entropy_model_params.get("head_dim") is not None - else None, - "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) - if entropy_model_params.get("n_kv_heads") is not None - else None, - "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), - "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), - "patcher_dropout": entropy_model_params.get("dropout", 0.0), - "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) - if entropy_model_params.get("sliding_window") is not None - else None, - "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), - "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), - "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), - "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get( - "rope_use_fp32_in_outer_product", False - ), - "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), - "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), - "patcher_init_base_std": entropy_model_params.get("init_base_std"), - "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), - "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), - "patcher_weight_tying": entropy_model_params.get("weight_tying", False), - "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), - "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), + "patcher_args": patcher_config, } ) @@ -168,8 +173,7 @@ def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tens "n_layers", "n_heads", "patch_in_forward", - "patcher_vocab_size", - "patcher_dim", + "patcher_args", ] missing_keys = [key for key in required_keys if key not in config] @@ -339,7 +343,7 @@ def main(): parser.add_argument( "--push_to_hub", type=str, - default="itazap/blt-1b", + default="itazap/blt-1b-hf", ) parser.add_argument( "--hub_private", @@ -349,7 +353,7 @@ def main(): parser.add_argument( "--hub_token", type=str, - default="hf_your_token_here", + default="hf_token", ) args = parser.parse_args() diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index bb05125ce248..bd5173e9d9f3 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -1,3 +1,5 @@ +# new config + # coding=utf-8 # Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. # @@ -37,6 +39,122 @@ class PatchingModeEnum(str, Enum): byte = "byte" +class BLTPatcherConfig(PretrainedConfig): + r""" + Configuration class for the BLT Patcher/Entropy model component. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size for the entropy model used in patching. + dim (`int`, *optional*, defaults to 512): + Hidden dimension for the entropy model. + n_layers (`int`, *optional*, defaults to 8): + Number of layers in the entropy model. + n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the entropy model. + head_dim (`int`, *optional*): + Dimension of each attention head in the entropy model. + n_kv_heads (`int`, *optional*): + Number of key-value heads in the entropy model. + max_seqlen (`int`, *optional*, defaults to 1024): + Maximum sequence length for the entropy model. + norm_eps (`float`, *optional*, defaults to 1e-5): + Layer normalization epsilon for the entropy model. + dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the entropy model. + sliding_window (`int`, *optional*): + Sliding window size for the entropy model attention. + ffn_dim_multiplier (`float`, *optional*): + Feedforward dimension multiplier for the entropy model. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this for the entropy model. + rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter for the entropy model. + rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product for the entropy model. + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation for the entropy model. + attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type for the entropy model. + init_base_std (`float`, *optional*): + Base initialization standard deviation for the entropy model. + init_std_factor (`str`, *optional*, defaults to "disabled"): + Initialization std factor for the entropy model. + dim_token_emb (`int`, *optional*): + Token embedding dimension for the entropy model. + weight_tying (`bool`, *optional*, defaults to False): + Whether to tie embeddings in the entropy model. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of sequence token id for the entropy model. + eos_token_id (`int`, *optional*, defaults to 2): + End of sequence token id for the entropy model. + """ + + model_type = "blt_patcher" + + def __init__( + self, + vocab_size=256, + dim=512, + n_layers=8, + n_heads=8, + head_dim=None, + n_kv_heads=None, + max_seqlen=1024, + norm_eps=1e-5, + dropout=0.0, + sliding_window=None, + ffn_dim_multiplier=None, + multiple_of=256, + rope_theta=10000.0, + rope_use_fp32_in_outer_product=False, + attn_impl="sdpa", + attn_bias_type="causal", + init_base_std=None, + init_std_factor="disabled", + dim_token_emb=None, + weight_tying=False, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.head_dim = head_dim if head_dim is not None else (dim // n_heads) + self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads + self.max_seqlen = max_seqlen + self.norm_eps = norm_eps + self.dropout = dropout + self.sliding_window = sliding_window + self.ffn_dim_multiplier = ffn_dim_multiplier + self.multiple_of = multiple_of + self.rope_theta = rope_theta + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.init_base_std = init_base_std + self.init_std_factor = InitStdFactor(init_std_factor) + self.dim_token_emb = dim_token_emb + self.weight_tying = weight_tying + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + # Add attributes needed for compatibility with transformer models + self.hidden_size = dim + self.num_attention_heads = n_heads + self.num_key_value_heads = self.n_kv_heads # Use the computed n_kv_heads + self.max_position_embeddings = max_seqlen + + # Set simple rope scaling for patcher (no complex dynamic rope) + self.rope_scaling = {"rope_type": "default"} + + class BLTConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a @@ -215,51 +333,10 @@ class BLTConfig(PretrainedConfig): pad_token_id (`int`, *optional*, defaults to -1): The id of the padding token. - # Patcher/Entropy model configuration - patcher_vocab_size (`int`, *optional*, defaults to 256): - Vocabulary size for the entropy model used in patching. - patcher_dim (`int`, *optional*, defaults to 512): - Hidden dimension for the entropy model. - patcher_n_layers (`int`, *optional*, defaults to 8): - Number of layers in the entropy model. - patcher_n_heads (`int`, *optional*, defaults to 8): - Number of attention heads in the entropy model. - patcher_head_dim (`int`, *optional*): - Dimension of each attention head in the entropy model. - patcher_n_kv_heads (`int`, *optional*): - Number of key-value heads in the entropy model. - patcher_max_seqlen (`int`, *optional*, defaults to 1024): - Maximum sequence length for the entropy model. - patcher_norm_eps (`float`, *optional*, defaults to 1e-5): - Layer normalization epsilon for the entropy model. - patcher_dropout (`float`, *optional*, defaults to 0.0): - Dropout probability for the entropy model. - patcher_sliding_window (`int`, *optional*): - Sliding window size for the entropy model attention. - patcher_ffn_dim_multiplier (`float`, *optional*): - Feedforward dimension multiplier for the entropy model. - patcher_multiple_of (`int`, *optional*, defaults to 256): - Make feedforward dimension multiple of this for the entropy model. - patcher_rope_theta (`float`, *optional*, defaults to 10000.0): - RoPE theta parameter for the entropy model. - patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): - Whether to use fp32 in RoPE outer product for the entropy model. - patcher_attn_impl (`str`, *optional*, defaults to "sdpa"): - Attention implementation for the entropy model. - patcher_attn_bias_type (`str`, *optional*, defaults to "causal"): - Attention bias type for the entropy model. - patcher_init_base_std (`float`, *optional*): - Base initialization standard deviation for the entropy model. - patcher_init_std_factor (`str`, *optional*, defaults to "disabled"): - Initialization std factor for the entropy model. - patcher_dim_token_emb (`int`, *optional*): - Token embedding dimension for the entropy model. - patcher_weight_tying (`bool`, *optional*, defaults to False): - Whether to tie embeddings in the entropy model. - patcher_bos_token_id (`int`, *optional*, defaults to 1): - Beginning of sequence token id for the entropy model. - patcher_eos_token_id (`int`, *optional*, defaults to 2): - End of sequence token id for the entropy model. + # Patcher configuration + patcher_args (`dict`, *optional*): + Dictionary containing configuration arguments for the BLT patcher/entropy model. + If provided, these will be used to initialize a BLTPatcherConfig instance. ```python >>> from transformers import ByteLatentTransformer, BLTConfig @@ -363,32 +440,14 @@ def __init__( bos_token_id=1, eos_token_id=2, pad_token_id=-1, - # Patcher/Entropy model configuration - patcher_vocab_size=256, - patcher_dim=512, - patcher_n_layers=8, - patcher_n_heads=8, - patcher_head_dim=None, - patcher_n_kv_heads=None, - patcher_max_seqlen=1024, - patcher_norm_eps=1e-5, - patcher_dropout=0.0, - patcher_sliding_window=None, - patcher_ffn_dim_multiplier=None, - patcher_multiple_of=256, - patcher_rope_theta=10000.0, - patcher_rope_use_fp32_in_outer_product=False, - patcher_attn_impl="sdpa", - patcher_attn_bias_type="causal", - patcher_init_base_std=None, - patcher_init_std_factor="disabled", - patcher_dim_token_emb=None, - patcher_weight_tying=False, - patcher_bos_token_id=1, - patcher_eos_token_id=2, + boe_id=0, + # Patcher configuration + patcher_args=None, # Inherited **kwargs, ): + + self.sliding_window = None # Basic model configuration self.vocab_size = vocab_size self.max_seqlen = max_seqlen @@ -397,7 +456,7 @@ def __init__( self.dim = dim self.n_layers = n_layers self.n_heads = n_heads - self.head_dim = head_dim + self.head_dim = head_dim if head_dim is not None else (dim // n_heads) self.n_kv_heads = n_kv_heads # Component-specific dimensions @@ -483,30 +542,16 @@ def __init__( # Parameter mixing self.pm_size = pm_size + + # Special token IDs + self.boe_id = boe_id - # Patcher/Entropy model configuration - self.patcher_vocab_size = patcher_vocab_size - self.patcher_dim = patcher_dim - self.patcher_n_layers = patcher_n_layers - self.patcher_n_heads = patcher_n_heads - self.patcher_head_dim = patcher_head_dim - self.patcher_n_kv_heads = patcher_n_kv_heads - self.patcher_max_seqlen = patcher_max_seqlen - self.patcher_norm_eps = patcher_norm_eps - self.patcher_dropout = patcher_dropout - self.patcher_sliding_window = patcher_sliding_window - self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier - self.patcher_multiple_of = patcher_multiple_of - self.patcher_rope_theta = patcher_rope_theta - self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product - self.patcher_attn_impl = patcher_attn_impl - self.patcher_attn_bias_type = patcher_attn_bias_type - self.patcher_init_base_std = patcher_init_base_std - self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor) - self.patcher_dim_token_emb = patcher_dim_token_emb - self.patcher_weight_tying = patcher_weight_tying - self.patcher_bos_token_id = patcher_bos_token_id - self.patcher_eos_token_id = patcher_eos_token_id + # Initialize patcher configuration + if patcher_args is not None: + self.patcher_config = BLTPatcherConfig(**patcher_args) + else: + # Use default values if no patcher_args provided + self.patcher_config = BLTPatcherConfig() # Handle hash byte group size validation if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: @@ -516,9 +561,8 @@ def __init__( # Rope self.rope_scaling={ - "type": "dynamic", - "factor": 2.0, - "rope_type": "dynamic" + "type": "default", + "rope_type": "default" } self.num_key_value_heads=n_heads_local_encoder @@ -534,6 +578,8 @@ def __init__( **kwargs, ) + + @property def encoder_dim_token_emb(self): """Compute encoder token embedding dimension.""" @@ -600,4 +646,5 @@ def get_init_std_factor(self, depth: int) -> float: return 1.0 -__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] +__all__ = ["BLTConfig", "BLTPatcherConfig", "InitStdFactor", "PatchingModeEnum"] + diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index bb388a5e6fd1..fbc8cbdd2a80 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -550,20 +550,21 @@ def __init__(self, config: BLTConfig): self.tok_embeddings = nn.Embedding(config.vocab_size + config.pm_size, config.dim_local_encoder) - # Initialize cross attention layers as None (will be set if needed) + # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.n_layers_local_encoder if config.cross_attn_all_layers_encoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=config.dim_local_encoder, - head_dim=config.dim_local_encoder // config.cross_attn_nheads, - n_heads=config.cross_attn_nheads, - n_kv_heads=config.cross_attn_nheads, - norm_eps=config.norm_eps, + if getattr(config, "cross_attn_encoder", False) and config.cross_attn_nheads is not None: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.n_layers_local_encoder if config.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=config.dim_local_encoder, + head_dim=config.dim_local_encoder // config.cross_attn_nheads, + n_heads=config.cross_attn_nheads, + n_kv_heads=config.cross_attn_nheads, + norm_eps=config.norm_eps, + ) ) - ) def forward( self, @@ -679,12 +680,15 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTConfig): super().__init__() - self.config = config self.layers = nn.ModuleList([BLTTransformerLayer(config.dim_local_decoder, config.n_heads_local_decoder, config) for _ in range(config.n_layers_local_decoder)]) - self.rotary_emb = BLTRotaryEmbedding(config=config) + decoder_config = config + decoder_config.head_dim = config.dim_local_decoder // config.n_heads_local_decoder + decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen + + self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) self.pos_embeddings = None @@ -698,20 +702,21 @@ def __init__(self, config: BLTConfig): self.norm = RMSNorm(config.dim_local_decoder, eps=config.norm_eps) - # Initialize cross attention layers as None (will be set if needed) + # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.n_layers_local_decoder if config.cross_attn_all_layers_decoder else 1 - for _ in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention( - dim=config.dim_local_decoder, - head_dim=config.dim_local_decoder // config.cross_attn_nheads, - n_heads=config.cross_attn_nheads, - n_kv_heads=config.cross_attn_nheads, - norm_eps=config.norm_eps, + if getattr(config, "cross_attn_decoder", False) and config.cross_attn_nheads is not None: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.n_layers_local_decoder if config.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=config.dim_local_decoder, + head_dim=config.dim_local_decoder // config.cross_attn_nheads, + n_heads=config.cross_attn_nheads, + n_kv_heads=config.cross_attn_nheads, + norm_eps=config.norm_eps, + ) ) - ) self.output = nn.Linear(config.dim_local_decoder, config.vocab_size, bias=False) @@ -750,6 +755,8 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): bs, seqlen = tokens.shape + batch_size, seq_length, _ = embeds.shape + assert embeds is not None, "Embeddings must be provided" if mask is None: @@ -761,12 +768,8 @@ def forward( sliding_window=self.config.sliding_window, ) - h = embeds - batch_size, seq_length, _ = embeds.shape - - if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) @@ -792,7 +795,6 @@ def forward( h = layer(h, position_embeddings=position_embeddings, attention_mask=None) - h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.config.dropout, training=self.training) h_preds = self.output(h_preds) @@ -901,7 +903,10 @@ def __init__(self, config): self.layers.append(BLTTransformerLayer(self.config.dim_global, self.config.n_heads_global, config)) config.n_kv_heads = old - self.rotary_emb = BLTRotaryEmbedding(config=config) + global_config = config + global_config.head_dim = config.dim_global // config.n_heads_global + + self.rotary_emb = BLTRotaryEmbedding(config=global_config) self.token_embedding_projection = None if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != config.dim_global: @@ -1046,7 +1051,7 @@ def _init_weights(self, module): module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) elif isinstance(module, BLTPatcher): - emb_std = module.config.patcher_dim ** (-0.5) + emb_std = module.config.dim ** (-0.5) module.tok_embeddings._custom_std = emb_std module.output._custom_std = emb_std @@ -1251,42 +1256,28 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): - config.num_attention_heads = config.patcher_n_heads - config.hidden_size = config.patcher_dim - - - config.n_heads = config.patcher_n_heads - - - config.num_key_value_heads = 12 # config.num_key_value_heads #TODO: add patcher_n_kv_heads - - config.head_dim = 64 #self.config.patcher_head_dim #TODO: add - super().__init__(config) - - self.config.hidden_size = self.config.patcher_dim - self.config.n_heads = self.config.patcher_n_heads - self.config.head_dim = 64 #self.config.patcher_head_dim #TODO: add + # Store reference to main config for accessing non-patcher settings + self.main_config = config + + # Initialize with patcher config directly + super().__init__(config.patcher_config) - # Create a patcher-specific config copy to use patcher_rope_theta and simple rope - import copy - patcher_config = copy.deepcopy(config) - patcher_config.rope_theta = config.patcher_rope_theta - patcher_config.rope_scaling = {"rope_type": "default"} # Use simple default rope for patcher - self.rotary_emb = BLTRotaryEmbedding(config=patcher_config) + # Initialize rotary embeddings with patcher config + self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - for _ in range(config.patcher_n_layers): - self.layers.append(BLTTransformerLayer(config.patcher_dim, config.patcher_n_heads, config)) + for _ in range(self.config.n_layers): + self.layers.append(BLTTransformerLayer(self.config.dim, self.config.n_heads, self.config)) - #assert config.patcher_vocab_size > 0 + #assert self.config.vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) + self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim) - self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) + self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps) self.output = nn.Linear( - config.patcher_dim, - config.patcher_vocab_size, + self.config.dim, + self.config.vocab_size, bias=False, ) @@ -1304,7 +1295,7 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] preds = [] - max_length = self.config.patcher_max_seqlen + max_length = self.config.max_seqlen batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) From 8d4df991c71f1d65cf85d4e853692c08b335120d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 08:36:15 +0000 Subject: [PATCH 022/139] rename vars to be more transformers-like --- .../models/blt_wip/modeling_blt.py | 484 ++++++++++-------- 1 file changed, 265 insertions(+), 219 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index fbc8cbdd2a80..54ae9116f80b 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -36,13 +36,13 @@ def cross_entropy(pred, target, **kwargs): def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." - bs, slen, n_kv_heads, head_dim = x.shape + batch_size, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + .expand(batch_size, slen, n_kv_heads, n_rep, head_dim) + .reshape(batch_size, slen, n_kv_heads * n_rep, head_dim) ) class BLTMLP(nn.Module): @@ -89,13 +89,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -297,10 +290,10 @@ def forward( position_embeddings=position_embeddings ) - h = residual + hidden_states - h_norm = self.ffn_norm(h) - out = h + self.feed_forward(h_norm) - return out + hidden_states = residual + hidden_states + normalized_hidden_states = self.ffn_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(normalized_hidden_states) + return hidden_states def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 @@ -314,7 +307,7 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() -def rolling_polynomial_hash(t, hash_func_nb: int = 0): +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ 1000000007, 5915587277, @@ -328,25 +321,25 @@ def rolling_polynomial_hash(t, hash_func_nb: int = 0): 5463458053, 3367900313, ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) - prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) - return torch.sum(t * prime_powers, dim=-1) + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) + return torch.sum(token_tensor * prime_powers, dim=-1) -def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): """ - Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. - expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. Note: max hash can make a big difference on the number of collisions. """ with torch.no_grad(): - bs, seq_len = x.shape - prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) - x = torch.cat([prefix, x], dim=1) - windows = x.unfold(1, group_size, 1) + batch_size, seq_len = token_ids.shape + prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + token_ids = torch.cat([prefix, token_ids], dim=1) + windows = token_ids.unfold(1, group_size, 1) # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values_range = hashes % max_hash @@ -356,32 +349,32 @@ def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): """ - Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + Creates a tensor of shape [batch_size, seq_len, num_patches] where each element at position (i, j, k) is True if the patch id at position (i, j) is less than or equal to k. Args: - patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. num_patches (int): Total number of patches. window (int): If not None, only considers patches within a window of size window. patches_as_queries (bool): If True, the patches are used as queries Returns: - torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + torch.Tensor: Tensor of shape [batch_size, q_len, kv_len] with the desired mask. """ - bs, seq_len = patch_ids.shape + batch_size, seq_len = patch_ids.shape if not patches_as_queries: - q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + q_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) kv_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(0) - .expand(bs, seq_len, num_patches) + .expand(batch_size, seq_len, num_patches) ) else: - kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + kv_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) q_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(-1) - .expand(bs, num_patches, seq_len) + .expand(batch_size, num_patches, seq_len) ) if window is None: mask = q_ids == kv_ids @@ -399,7 +392,7 @@ def cross_attn_mask( window=None, block_mask=True, ): - bs = patch_ids.shape[0] + batch_size = patch_ids.shape[0] with torch.no_grad(): # Create the patch mask cross_mask = create_patch_mask_from_ids( @@ -411,19 +404,19 @@ def cross_attn_mask( q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k assert cross_mask.shape == ( - bs, + batch_size, q_len, kv_len, - ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + ), f"{cross_mask.shape} != {(batch_size, q_len, kv_len)}" block_mask = None if block_mask: - def patch_mask(b, h, q_idx, kv_idx): + def patch_mask(b, num_heads, q_idx, kv_idx): return cross_mask[b, q_idx, kv_idx] block_mask = create_block_mask( patch_mask, - B=bs, + B=batch_size, H=None, Q_LEN=q_len, KV_LEN=kv_len, @@ -433,7 +426,7 @@ def patch_mask(b, h, q_idx, kv_idx): else: return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( 1 - ) # [bs, 1, q_len, kv_len] + ) # [batch_size, 1, q_len, kv_len] def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: @@ -531,38 +524,53 @@ def forward(self, x, position_ids): class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTConfig): super().__init__() - self.config = config #TODO: rm this - + + # Extract config values to instance attributes self.dropout = config.dropout + self.dim_local_encoder = config.dim_local_encoder + self.n_layers_local_encoder = config.n_layers_local_encoder + self.n_heads_local_encoder = config.n_heads_local_encoder + self.vocab_size = config.vocab_size + self.pm_size = config.pm_size + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", False) + self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = getattr(config, "cross_attn_init_by_pooling", False) + self.cross_attn_k = getattr(config, "cross_attn_k", 1) + self.norm_eps = config.norm_eps + self.sliding_window = getattr(config, "sliding_window", None) - self.layers = nn.ModuleList([BLTTransformerLayer(config.dim_local_encoder, config.n_heads_local_encoder, config) for _ in range(config.n_layers_local_encoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_encoder, self.n_heads_local_encoder, config) for _ in range(self.n_layers_local_encoder)]) - self.rotary_emb = BLTRotaryEmbedding(config=config) - self.pos_embeddings = None + # Set up config for rotary embedding + encoder_config = config + encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder + encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen + self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) self.token_embedding_projection = ( - nn.Linear(config.encoder_dim_token_emb, config.dim_local_encoder, bias=False) - if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != config.dim_local_encoder + nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) + if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder else None ) self.patch_embedding_projection = self._create_patch_projection(config) - self.tok_embeddings = nn.Embedding(config.vocab_size + config.pm_size, config.dim_local_encoder) + self.tok_embeddings = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None - if getattr(config, "cross_attn_encoder", False) and config.cross_attn_nheads is not None: + if self.cross_attn_encoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.n_layers_local_encoder if config.cross_attn_all_layers_encoder else 1 + layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( BLTCrossAttention( - dim=config.dim_local_encoder, - head_dim=config.dim_local_encoder // config.cross_attn_nheads, - n_heads=config.cross_attn_nheads, - n_kv_heads=config.cross_attn_nheads, - norm_eps=config.norm_eps, + dim=self.dim_local_encoder, + head_dim=self.dim_local_encoder // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=self.norm_eps, ) ) @@ -578,7 +586,7 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - bs, seqlen = input_ids.shape + batch_size, sequence_length = input_ids.shape if input_embeds is None: input_embeds = self.embed_tokens(input_ids) @@ -590,39 +598,40 @@ def forward( batch_size=batch_size, device=input_embeds.device, dtype=input_embeds.dtype, - sliding_window=self.config.sliding_window, + sliding_window=self.sliding_window, ) - h = input_embeds + hidden_states = input_embeds - h_residual = input_embeds - h = nn.functional.dropout(h, p=self.dropout, training=self.training) + residual_hidden_states = input_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(h, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) - h = F.dropout(h, p=self.config.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for idx, layer in enumerate(self.layers): - h = layer(h, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - if getattr(self.config, "cross_attn_encoder", None) and (idx == len(self.layers) - 1 or self.config.cross_attn_all_layers_encoder): - if self.config.cross_attn_init_by_pooling and patch_embeds is None: - patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # Initialize patch_embeds if not provided when cross attention is enabled + if patch_embeds is None: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * getattr(self.config, "cross_attn_k", 1), self.config.dim_local_encoder) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) - layer_idx = idx if self.config.cross_attn_all_layers_encoder else 0 - patch_embeds_cross = self.cross_attn_layers[layer_idx]( - x=patch_embeds, - kv=h, + layer_idx = idx if self.cross_attn_all_layers_encoder else 0 + cross_attention_output = self.cross_attn_layers[layer_idx]( + query_states=patch_embeds, + kv=hidden_states, mask=cross_mask, ) - patch_embeds = patch_embeds + patch_embeds_cross + patch_embeds = patch_embeds + cross_attention_output - h_residual = patch_embeds if getattr(self.config, "cross_attn_encoder", None) else None - return (h, h_residual), cache + encoder_cross_states = patch_embeds if self.cross_attn_encoder else None + return (hidden_states, encoder_cross_states), cache def _create_patch_projection(self, config): dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder @@ -642,14 +651,14 @@ def _create_patch_projection(self, config): bias=False, ) - def embed_tokens(self, tokens, embeds): + def embed_tokens(self, tokens, embeds=None): if embeds is not None: assert self.config.encoder_hash_byte_group_size is not None, "Not expecting embeddings to be passed." return embeds else: return self.tok_embeddings(tokens) - def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch @@ -660,65 +669,75 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - bs, seq_len, emb_dim = h.shape + batch_size, seq_len, embedding_dim = hidden_states.shape - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) - reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) - reduced_embs = reduced_embs.scatter_reduce( - src=h, + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, dim=1, index=patch_ids, reduce=reduction, include_self=False, ) - reduced_embs = reduced_embs[:, :max_num_patches, :] + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] - return reduced_embs + return reduced_embeddings class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTConfig): super().__init__() - self.config = config - self.layers = nn.ModuleList([BLTTransformerLayer(config.dim_local_decoder, config.n_heads_local_decoder, config) for _ in range(config.n_layers_local_decoder)]) + # Extract config values to instance attributes + self.dim_local_decoder = config.dim_local_decoder + self.n_heads_local_decoder = config.n_heads_local_decoder + self.n_layers_local_decoder = config.n_layers_local_decoder + self.vocab_size = config.vocab_size + self.norm_eps = config.norm_eps + self.dropout = config.dropout + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", False) + self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_k = getattr(config, "cross_attn_k", 1) + self.sliding_window = getattr(config, "sliding_window", None) + + self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_decoder, self.n_heads_local_decoder, config) for _ in range(self.n_layers_local_decoder)]) decoder_config = config - decoder_config.head_dim = config.dim_local_decoder // config.n_heads_local_decoder + decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) - self.pos_embeddings = None - self.token_embedding_projection = ( - nn.Linear(config.decoder_dim_token_emb, config.dim_local_decoder, bias=False) - if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != config.dim_local_decoder + nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) + if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder else None ) self.patch_embedding_projection = self._create_patch_projection(config) - self.norm = RMSNorm(config.dim_local_decoder, eps=config.norm_eps) + self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps) # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None - if getattr(config, "cross_attn_decoder", False) and config.cross_attn_nheads is not None: + if self.cross_attn_decoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.n_layers_local_decoder if config.cross_attn_all_layers_decoder else 1 + layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( BLTCrossAttention( - dim=config.dim_local_decoder, - head_dim=config.dim_local_decoder // config.cross_attn_nheads, - n_heads=config.cross_attn_nheads, - n_kv_heads=config.cross_attn_nheads, - norm_eps=config.norm_eps, + dim=self.dim_local_decoder, + head_dim=self.dim_local_decoder // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=self.norm_eps, ) ) - self.output = nn.Linear(config.dim_local_decoder, config.vocab_size, bias=False) + self.output = nn.Linear(self.dim_local_decoder, self.vocab_size, bias=False) def _create_patch_projection(self, config): dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder @@ -754,7 +773,7 @@ def forward( cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - bs, seqlen = tokens.shape + batch_size, sequence_length = tokens.shape batch_size, seq_length, _ = embeds.shape assert embeds is not None, "Embeddings must be provided" @@ -765,41 +784,41 @@ def forward( batch_size=batch_size, device=embeds.device, dtype=embeds.dtype, - sliding_window=self.config.sliding_window, + sliding_window=self.sliding_window, ) - h = embeds + hidden_states = embeds if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) - if getattr(self.config, "cross_attn_k", None) is not None: - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.dim_local_decoder) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) - if patch_embeds is not None and not getattr(self.config, "cross_attn_decoder", None): - h = h + patch_embeds + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(h, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) - h = F.dropout(h, p=self.config.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - if getattr(self.config, "cross_attn_decoder", None) and (i == 0 or self.config.cross_attn_all_layers_decoder): - # Use cross attention to extract info from patch_embeds into h - h_cross = self.cross_attn_layers[i]( - x=h, + if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output = self.cross_attn_layers[i]( + query_states=hidden_states, kv=patch_embeds, mask=cross_mask, ) - h = h + h_cross + hidden_states = hidden_states + cross_attention_output - h = layer(h, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - h_preds = self.norm(h) - h_preds = F.dropout(h_preds, p=self.config.dropout, training=self.training) - h_preds = self.output(h_preds) - h_preds = h_preds.float() - return h_preds, cache + logits = self.norm(hidden_states) + logits = F.dropout(logits, p=self.dropout, training=self.training) + logits = self.output(logits) + logits = logits.float() + return logits, cache class BLTCrossAttention(nn.Module): @@ -847,39 +866,43 @@ def __init__( def forward( self, - x: torch.Tensor, + query_states: torch.Tensor, kv: torch.Tensor, mask: Optional[Union[BlockMask, str]] = None, ) -> torch.Tensor: # B S D - bsz, seq_len, _ = x.shape - _, slen_kv, _ = kv.shape - x_norm = self.cross_attn_norm_q(x) + batch_size, seq_len, _ = query_states.shape + _, kv_seq_len, _ = kv.shape + + # Store original input for residual connection + residual = query_states + + query_norm = self.cross_attn_norm_q(query_states) kv = self.cross_attn_norm_kv(kv) - xq = self.wq(x_norm) - xk = self.wk(kv) - xv = self.wv(kv) + query_proj = self.wq(query_norm) + key_states = self.wk(kv) + value_states = self.wv(kv) - output_shape = xq.shape + output_shape = query_proj.shape # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + query_proj = query_proj.view(batch_size, seq_len, self.n_heads, self.head_dim) + key_states = key_states.view(batch_size, kv_seq_len, self.n_kv_heads, self.head_dim) + value_states = value_states.view(batch_size, kv_seq_len, self.n_kv_heads, self.head_dim) - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) + key_states = repeat_kv(key_states, self.heads_per_group, dim=2) + value_states = repeat_kv(value_states, self.heads_per_group, dim=2) # assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - # output = flex_attention_comp(xq, xk, xv, block_mask=mask) + query_proj, key_states, value_states = (e.transpose(1, 2) for e in (query_proj, key_states, value_states)) + # output = flex_attention_comp(query_proj, key_states, value_states, block_mask=mask) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask if isinstance(mask, torch.Tensor) else None - mask = mask.to(dtype=xq.dtype).to(xq.device) + mask = mask.to(dtype=query_proj.dtype).to(query_proj.device) output = F.scaled_dot_product_attention( - xq, - xk, - xv, + query_proj, + key_states, + value_states, is_causal=is_causal, attn_mask=mask, ) @@ -887,32 +910,36 @@ def forward( output = self.wo(output.reshape(output_shape)) - return x + output + return residual + output class BLTGlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - self.config = config + # Extract config values to instance attributes + self.dim_global = config.dim_global + self.n_heads_global = config.n_heads_global + self.n_layers_global = config.n_layers_global + self.dropout = config.dropout self.layers = nn.ModuleList() old = config.n_kv_heads config.n_kv_heads = config.n_kv_heads_global - for _ in range(config.n_layers_global): - self.layers.append(BLTTransformerLayer(self.config.dim_global, self.config.n_heads_global, config)) + for _ in range(self.n_layers_global): + self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) config.n_kv_heads = old global_config = config - global_config.head_dim = config.dim_global // config.n_heads_global + global_config.head_dim = self.dim_global // self.n_heads_global self.rotary_emb = BLTRotaryEmbedding(config=global_config) self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != config.dim_global: + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: self.token_embedding_projection = nn.Linear( config.global_dim_patch_emb, - config.dim_global, + self.dim_global, bias=False, ) @@ -926,20 +953,20 @@ def forward( ): batch_size, seq_length, _ = input_embeds.shape - h = input_embeds + hidden_states = input_embeds - if self.token_embedding_projection is not None and h.shape[-1] != self.config.dim_global: - h = self.token_embedding_projection(h) + if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: + hidden_states = self.token_embedding_projection(hidden_states) - h = F.dropout(h, p=self.config.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(h, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - h = layer(h, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - return h, cache + return hidden_states, cache def compute_hash_embeddings( @@ -1020,11 +1047,6 @@ def _init_weights(self, module): a=-3 * std, b=3 * std, ) - - # elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): - # nn.init.ones_(module.weight) - # if module.bias is not None: - # nn.init.zeros_(module.bias) elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: @@ -1060,7 +1082,23 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - self.config = config + # Extract frequently used config values + self.patch_in_forward = config.patch_in_forward + self.patching_mode = config.patching_mode + self.patch_size = config.patch_size + self.patching_threshold = config.patching_threshold + self.max_patch_length = config.max_patch_length + self.patching_batch_size = config.patching_batch_size + self.patching_device = config.patching_device + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", False) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", False) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.cross_attn_window_encoder = getattr(config, "cross_attn_window_encoder", None) + self.cross_attn_window_decoder = getattr(config, "cross_attn_window_decoder", None) + self.cross_attn_use_flex_attention = getattr(config, "cross_attn_use_flex_attention", True) + self.boe_id = config.boe_id + self.eos_token_id = config.eos_token_id + self.local_encoder = BLTLocalEncoder(config) self.global_transformer = BLTGlobalTransformer(config) self.local_decoder = BLTLocalDecoder(config) @@ -1071,7 +1109,7 @@ def __init__(self, config: BLTConfig): encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if config.patch_in_forward: + if self.patch_in_forward: self.patcher = BLTPatcher(config) self.patcher.eval() for param in self.patcher.parameters(): @@ -1087,36 +1125,36 @@ def forward( # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings # are no longer used in the final BLT model - bs, N = tokens.shape # Batch size and sequence length + batch_size, sequence_length = tokens.shape # Batch size and sequence length local_encoder_tokens, local_decoder_tokens = tokens, tokens # Patching if patch_lengths is None: # assert ( - # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward # ), "Patch in forward not enabled and no patch_lengths passed." # PATCHER MODEL DEFINED - if self.config.patching_mode == PatchingModeEnum.entropy: + if self.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( local_encoder_tokens, - patch_size=self.config.patch_size, + patch_size=self.patch_size, include_next_token=True, - threshold=self.config.patching_threshold, - max_patch_length=self.config.max_patch_length, - patching_batch_size=self.config.patching_batch_size, - device=self.config.patching_device, + threshold=self.patching_threshold, + max_patch_length=self.max_patch_length, + patching_batch_size=self.patching_batch_size, + device=self.patching_device, ) else: - # self.config.patching_mode == PatchingModeEnum.byte - bs, seq_len = local_encoder_tokens.shape + # self.patching_mode == PatchingModeEnum.byte + batch_size_tokens, seq_len = local_encoder_tokens.shape seq_len_next_tok = seq_len + 1 # include_next_token=True patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device ) - patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) + patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) #assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths @@ -1127,15 +1165,15 @@ def forward( cross_attn_mask_enc = None # Cross-attention encoder - if self.config.cross_attn_encoder: + if self.cross_attn_encoder: cross_attn_mask_enc = cross_attn_mask( patch_ids, patch_lengths, - N, + sequence_length, patches_as_queries=True, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_encoder, - block_mask=self.config.cross_attn_use_flex_attention, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, ) # Hashing and embedding @@ -1152,7 +1190,7 @@ def forward( # The final BLT model uses only hash-based n-gram embeddings # Local encoder - (h_encoder, h_cross), cache_encoder = self.local_encoder( + (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( input_ids=local_encoder_tokens, input_embeds=local_encoder_embeds, patch_embeds=None, @@ -1162,50 +1200,58 @@ def forward( ) # Downsampling - h = h_cross.view(bs, patch_lengths.shape[1], -1) + if encoder_cross_states is not None: + # Cross attention is enabled - use cross states + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + else: + # Cross attention is disabled - use reduced embeddings from encoder hidden states + global_hidden_states = self.local_encoder.patch_reduce( + encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids + ) # Global transformer - global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + global_tokens[rows, eos_patch_ids] = self.eos_token_id - h, _ = self.global_transformer( - input_embeds=h, + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, input_ids=global_tokens, ) # Unpatching - dec_embeds = h_encoder + decoder_embeds = encoder_hidden_states # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) - # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" - # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( - # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" + # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" # ) # Cross-attention decoder - if not self.config.cross_attn_decoder: - h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) + if not self.cross_attn_decoder: + patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) cross_attn_mask_dec = None - # assert local_decoder_tokens.shape == h.shape[:-1] + # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] else: + patch_hidden_states = global_hidden_states cross_attn_mask_dec = cross_attn_mask( decoder_patch_ids, patch_lengths, - N, + sequence_length, patches_as_queries=False, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_decoder, - block_mask=self.config.cross_attn_use_flex_attention, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, ) # Local decoder output, _ = self.local_decoder( - embeds=dec_embeds, - patch_embeds=h, + embeds=decoder_embeds, + patch_embeds=patch_hidden_states, tokens=local_decoder_tokens, cross_mask=cross_attn_mask_dec, ) @@ -1294,7 +1340,7 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] - preds = [] + predictions = [] max_length = self.config.max_seqlen batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) @@ -1308,7 +1354,7 @@ def forward( split = split.to(device) # Process chunk: embeddings -> layers -> output - bsz, seqlen = split.shape + batch_size, sequence_length = split.shape input_embeds = self.tok_embeddings(split) hidden_states = input_embeds @@ -1323,18 +1369,18 @@ def forward( for i, layer in enumerate(self.layers): hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) - pred = self.output(self.norm(hidden_states)) - pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] - preds.append(pred) - pred_entropies = self.entropy(pred) - entropies.append(pred_entropies) + logits = self.output(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + predictions.append(logits) + prediction_entropies = self.entropy(logits) + entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) # Always compute patch lengths from concatenated entropies - bs, seq_len = token_values.shape - seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + batch_size, sequence_length = token_values.shape + seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length # Find patch start IDs based on entropy if patch_size is not None: @@ -1347,17 +1393,17 @@ def forward( patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching - patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_preds + return concat_entropies, patch_lengths, concat_predictions @staticmethod def entropy(scores): """ - scores: [bs, seq_len, vocab] - returns [bs, seq_len] + scores: [batch_size, seq_len, vocab] + returns [batch_size, seq_len] Computes the entropy for each token in the batch. Note: uses natural log. @@ -1370,26 +1416,26 @@ def entropy(scores): @staticmethod def patch_start_ids_from_patch_start_mask(patch_start_mask): - bs, trunc_seq_len = patch_start_mask.shape + batch_size, trunc_seq_len = patch_start_mask.shape max_patches = patch_start_mask.sum(dim=1).max() if max_patches == 0: patch_start_ids = torch.full( - (bs, trunc_seq_len), + (batch_size, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) else: - patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) extra_patch_ids = torch.full( - (bs, trunc_seq_len), + (batch_size, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] return patch_start_ids @staticmethod @@ -1425,13 +1471,13 @@ def find_entropy_patch_start_ids( different sequences, but patches can be identified incrementally rather than decided globally using the entire sequence. """ - bs, seq_len = entropies.shape[:2] + batch_size, sequence_length = entropies.shape[:2] - first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) - preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. entropies = entropies[:, 1:] if threshold is None: - num_patches = seq_len // patch_size + num_patches = sequence_length // patch_size patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices patch_start_ids = patch_start_ids.sort(dim=1).values else: @@ -1441,7 +1487,7 @@ def find_entropy_patch_start_ids( # patch_start_mask[1:] |= tokens[:-1] < OFFSET patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) + patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) return patch_start_ids def init_hash_embeddings( From d938a2f19ba4eb156cf85bcf997a6848dd326b22 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 08:47:50 +0000 Subject: [PATCH 023/139] rm unused functions --- .../models/blt_wip/modeling_blt.py | 57 ++----------------- 1 file changed, 4 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 54ae9116f80b..ca6ff8c07b93 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -460,37 +460,6 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> return padded -def create_causal_mask_for_blt( - seqlen: int, - batch_size: int, - device: torch.device, - dtype: torch.dtype, - sliding_window: Optional[int] = None, -) -> torch.Tensor: - """ - Creates a causal mask for BLT local encoder. - """ - min_value = torch.finfo(dtype).min - mask = torch.full( - (batch_size, 1, seqlen, seqlen), # Note: using seqlen, not total_seqlen - min_value, - dtype=dtype, - device=device, - ) - - if sliding_window is not None: - # Create local causal mask with sliding window - for i in range(seqlen): - start_idx = max(0, i - sliding_window + 1) - mask[:, :, i, start_idx:i + 1] = 0 - else: - # Create full causal mask - mask = torch.triu(mask, diagonal=0) - mask = mask.masked_fill(mask == 0, min_value) - - return mask - - class BLTRotaryEmbedding(nn.Module): def __init__(self, config: BLTConfig, device=None): super().__init__() @@ -592,15 +561,6 @@ def forward( batch_size, seq_length, _ = input_embeds.shape - if mask is None: - attention_mask = create_causal_mask_for_blt( - seqlen=seq_length, - batch_size=batch_size, - device=input_embeds.device, - dtype=input_embeds.dtype, - sliding_window=self.sliding_window, - ) - hidden_states = input_embeds residual_hidden_states = input_embeds @@ -773,20 +733,11 @@ def forward( cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, sequence_length = tokens.shape - batch_size, seq_length, _ = embeds.shape + batch_size, _ = tokens.shape + batch_size, _, _ = embeds.shape assert embeds is not None, "Embeddings must be provided" - if mask is None: - attention_mask = create_causal_mask_for_blt( - seqlen=seq_length, - batch_size=batch_size, - device=embeds.device, - dtype=embeds.dtype, - sliding_window=self.sliding_window, - ) - hidden_states = embeds if self.patch_embedding_projection is not None: @@ -924,11 +875,11 @@ def __init__(self, config): self.dropout = config.dropout self.layers = nn.ModuleList() - old = config.n_kv_heads + # old = config.n_kv_heads config.n_kv_heads = config.n_kv_heads_global for _ in range(self.n_layers_global): self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) - config.n_kv_heads = old + # config.n_kv_heads = old global_config = config global_config.head_dim = self.dim_global // self.n_heads_global From 3bcfc03d4e1f08a312d8d9ac17cd03c4ddb49ce7 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 10:21:47 +0000 Subject: [PATCH 024/139] adding cross attention from transformers --- .../models/blt_wip/modeling_blt.py | 260 +++++++++--------- 1 file changed, 130 insertions(+), 130 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index ca6ff8c07b93..3b6495431d50 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -2,9 +2,10 @@ import logging import os -from typing import List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union +from ...cache_utils import Cache + import torch import torch.nn import torch.nn as nn @@ -88,7 +89,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.w2(F.silu(x1) * x3) return output - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -238,12 +238,12 @@ def __init__(self, dim, n_heads, config, layer_idx=0): # Extract parameters from dictionary dim = dim n_heads = n_heads - head_dim = getattr(config, "head_dim", None) - n_kv_heads = getattr(config, "n_kv_heads", None) - rope_theta = getattr(config, "rope_theta", None) - multiple_of = getattr(config, "multiple_of", 256) - ffn_dim_multiplier = getattr(config, "ffn_dim_multiplier", None) - norm_eps = getattr(config, "norm_eps", None) + head_dim = config.head_dim + n_kv_heads = config.n_kv_heads + rope_theta = config.rope_theta + multiple_of = config.multiple_of + ffn_dim_multiplier = config.ffn_dim_multiplier + norm_eps = config.norm_eps self.head_dim = head_dim or dim // n_heads self.n_heads = n_heads or dim // head_dim @@ -501,13 +501,13 @@ def __init__(self, config: BLTConfig): self.n_heads_local_encoder = config.n_heads_local_encoder self.vocab_size = config.vocab_size self.pm_size = config.pm_size - self.cross_attn_encoder = getattr(config, "cross_attn_encoder", False) + self.cross_attn_encoder = config.cross_attn_encoder self.cross_attn_nheads = config.cross_attn_nheads self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = getattr(config, "cross_attn_init_by_pooling", False) - self.cross_attn_k = getattr(config, "cross_attn_k", 1) + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_k = config.cross_attn_k self.norm_eps = config.norm_eps - self.sliding_window = getattr(config, "sliding_window", None) + self.sliding_window = config.sliding_window self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_encoder, self.n_heads_local_encoder, config) for _ in range(self.n_layers_local_encoder)]) @@ -532,15 +532,9 @@ def __init__(self, config: BLTConfig): if self.cross_attn_encoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 - for _ in range(layers_to_add): + for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim_local_encoder, - head_dim=self.dim_local_encoder // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=self.norm_eps, - ) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) ) def forward( @@ -559,11 +553,10 @@ def forward( if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - batch_size, seq_length, _ = input_embeds.shape + batch_size, _, _ = input_embeds.shape hidden_states = input_embeds - residual_hidden_states = input_embeds hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) @@ -583,10 +576,13 @@ def forward( patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) layer_idx = idx if self.cross_attn_all_layers_encoder else 0 - cross_attention_output = self.cross_attn_layers[layer_idx]( - query_states=patch_embeds, - kv=hidden_states, - mask=cross_mask, + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + output_attentions=False, + use_cache=False, + cache_position=None, ) patch_embeds = patch_embeds + cross_attention_output @@ -603,7 +599,7 @@ def _create_patch_projection(self, config): if not (dimension_mismatch or cross_attn_conditions): return None - output_dim = config.encoder_dim_token_emb * (getattr(config, "cross_attn_k", None) or 1) + output_dim = config.encoder_dim_token_emb * config.cross_attn_k return nn.Linear( in_features=config.encoder_dim_patch_emb, @@ -657,11 +653,11 @@ def __init__(self, config: BLTConfig): self.vocab_size = config.vocab_size self.norm_eps = config.norm_eps self.dropout = config.dropout - self.cross_attn_decoder = getattr(config, "cross_attn_decoder", False) + self.cross_attn_decoder = config.cross_attn_decoder self.cross_attn_nheads = config.cross_attn_nheads self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder - self.cross_attn_k = getattr(config, "cross_attn_k", 1) - self.sliding_window = getattr(config, "sliding_window", None) + self.cross_attn_k = config.cross_attn_k + self.sliding_window = config.sliding_window self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_decoder, self.n_heads_local_decoder, config) for _ in range(self.n_layers_local_decoder)]) @@ -686,15 +682,9 @@ def __init__(self, config: BLTConfig): if self.cross_attn_decoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 - for _ in range(layers_to_add): + for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim_local_decoder, - head_dim=self.dim_local_decoder // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=self.norm_eps, - ) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) ) self.output = nn.Linear(self.dim_local_decoder, self.vocab_size, bias=False) @@ -710,7 +700,7 @@ def _create_patch_projection(self, config): if not (dimension_mismatch or cross_attn_conditions): return None - output_dim = config.decoder_dim_token_emb * (getattr(config, "cross_attn_k", None) or 1) + output_dim = config.decoder_dim_token_emb * config.cross_attn_k return nn.Linear( in_features=config.dim_global, @@ -733,8 +723,8 @@ def forward( cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, _ = tokens.shape - batch_size, _, _ = embeds.shape + batch_size, sequence_length = tokens.shape + batch_size, seq_length, _ = embeds.shape assert embeds is not None, "Embeddings must be provided" @@ -756,10 +746,13 @@ def forward( for i, layer in enumerate(self.layers): if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): # Use cross attention to extract info from patch_embeds into hidden_states - cross_attention_output = self.cross_attn_layers[i]( - query_states=hidden_states, - kv=patch_embeds, - mask=cross_mask, + cross_attention_output, _, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + output_attentions=False, + use_cache=False, + cache_position=None, ) hidden_states = hidden_states + cross_attention_output @@ -773,95 +766,102 @@ def forward( class BLTCrossAttention(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - ): - super().__init__() - - self.dim = dim - self.head_dim = head_dim + """Cross-attention module for BLT, following transformers style""" - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads - - self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) - self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.dim_local_encoder + self.num_heads = config.cross_attn_nheads + self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim ** -0.5 + self.dropout = config.dropout - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) + self.wq = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, - ) + self.cross_attn_norm_q = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.cross_attn_norm_kv = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) def forward( self, - query_states: torch.Tensor, - kv: torch.Tensor, - mask: Optional[Union[BlockMask, str]] = None, - ) -> torch.Tensor: - # B S D - batch_size, seq_len, _ = query_states.shape - _, kv_seq_len, _ = kv.shape - - # Store original input for residual connection - residual = query_states + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() - query_norm = self.cross_attn_norm_q(query_states) - kv = self.cross_attn_norm_kv(kv) - - query_proj = self.wq(query_norm) - key_states = self.wk(kv) - value_states = self.wv(kv) - - output_shape = query_proj.shape - # B S D -> B S H D - query_proj = query_proj.view(batch_size, seq_len, self.n_heads, self.head_dim) - key_states = key_states.view(batch_size, kv_seq_len, self.n_kv_heads, self.head_dim) - value_states = value_states.view(batch_size, kv_seq_len, self.n_kv_heads, self.head_dim) - - key_states = repeat_kv(key_states, self.heads_per_group, dim=2) - value_states = repeat_kv(value_states, self.heads_per_group, dim=2) - - # assert mask is None or isinstance(mask, BlockMask) - query_proj, key_states, value_states = (e.transpose(1, 2) for e in (query_proj, key_states, value_states)) - # output = flex_attention_comp(query_proj, key_states, value_states, block_mask=mask) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask if isinstance(mask, torch.Tensor) else None - mask = mask.to(dtype=query_proj.dtype).to(query_proj.device) - output = F.scaled_dot_product_attention( - query_proj, + query_states = self.cross_attn_norm_q(hidden_states) # BLT normalizes first + query_states = self.wq(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.cross_attn_norm_kv(cross_attention_states) # BLT normalizes first + key_states = self.wk(cross_attention_states) + value_states = self.wv(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" + ) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, key_states, value_states, - is_causal=is_causal, - attn_mask=mask, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - output = self.wo(output.reshape(output_shape)) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.wo(attn_output) + + attn_output = attn_output + hidden_states #TODO: they add the residual twice?? move this out - return residual + output + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class BLTGlobalTransformer(nn.Module): @@ -875,11 +875,11 @@ def __init__(self, config): self.dropout = config.dropout self.layers = nn.ModuleList() - # old = config.n_kv_heads + old = config.n_kv_heads config.n_kv_heads = config.n_kv_heads_global for _ in range(self.n_layers_global): self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) - # config.n_kv_heads = old + config.n_kv_heads = old global_config = config global_config.head_dim = self.dim_global // self.n_heads_global @@ -1041,12 +1041,12 @@ def __init__(self, config: BLTConfig): self.max_patch_length = config.max_patch_length self.patching_batch_size = config.patching_batch_size self.patching_device = config.patching_device - self.cross_attn_encoder = getattr(config, "cross_attn_encoder", False) - self.cross_attn_decoder = getattr(config, "cross_attn_decoder", False) - self.cross_attn_k = getattr(config, "cross_attn_k", None) - self.cross_attn_window_encoder = getattr(config, "cross_attn_window_encoder", None) - self.cross_attn_window_decoder = getattr(config, "cross_attn_window_decoder", None) - self.cross_attn_use_flex_attention = getattr(config, "cross_attn_use_flex_attention", True) + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_k = config.cross_attn_k + self.cross_attn_window_encoder = config.cross_attn_window_encoder + self.cross_attn_window_decoder = config.cross_attn_window_decoder + self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention self.boe_id = config.boe_id self.eos_token_id = config.eos_token_id From 2a7778c36f36e9a2d41915df01a91a99924aabe2 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 10:31:24 +0000 Subject: [PATCH 025/139] pass arg --- src/transformers/models/blt_wip/modeling_blt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 3b6495431d50..f8454d551d2c 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -778,7 +778,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = None #self.head_dim ** -0.5 self.dropout = config.dropout self.wq = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -848,7 +848,7 @@ def forward( key_states, value_states, attention_mask, - dropout=0.0 if not self.training else self.dropout, + dropout=0.0, #if not self.training else self.dropout, scaling=self.scaling, **kwargs, ) From 9ed04fda09db10727270654987828e8a37077562 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 12:26:32 +0000 Subject: [PATCH 026/139] rename weights --- src/demo_hf.py | 5 +- .../models/blt_wip/configuration_blt.py | 9 +- .../models/blt_wip/modeling_blt.py | 1460 +++++++++-------- 3 files changed, 798 insertions(+), 676 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index c935add72575..7140d9749e0b 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -101,7 +101,7 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - blt_repo = "itazap/blt-1b" + blt_repo = "itazap/blt-1b-converted" model = BLTModel.from_pretrained(blt_repo).to(device) tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) @@ -111,11 +111,14 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): outputs = generate(prompts, model=model, tokenizer=tokenizer, max_gen_len=200, device=device) text_outputs = [tokenizer.decode(t) for t in outputs] + for p, t in zip(prompts, text_outputs): print(f'Prompt: "{p}"') print(f'Completion: "{t}"') print() + print('here') + if __name__ == "__main__": main() diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index bd5173e9d9f3..a10bee26c182 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -411,7 +411,7 @@ def __init__( cross_attn_decoder=False, cross_attn_window_encoder=None, cross_attn_window_decoder=None, - cross_attn_k=None, + cross_attn_k=1, cross_attn_nheads=None, cross_attn_all_layers_decoder=False, cross_attn_all_layers_encoder=False, @@ -569,7 +569,6 @@ def __init__( self.max_position_embeddings=max_seqlen self.hidden_size=dim_local_encoder self.num_attention_heads=n_heads_local_encoder - # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads super().__init__( bos_token_id=bos_token_id, @@ -579,10 +578,8 @@ def __init__( ) - @property def encoder_dim_token_emb(self): - """Compute encoder token embedding dimension.""" if self.dim_token is not None: return self.dim_token elif self.use_local_encoder_transformer: @@ -594,7 +591,6 @@ def encoder_dim_token_emb(self): @property def encoder_dim_patch_emb(self): - """Compute encoder patch embedding dimension.""" if self.cross_attn_encoder: if self.cross_attn_init_by_pooling: return self.dim_local_encoder @@ -604,7 +600,6 @@ def encoder_dim_patch_emb(self): @property def global_dim_patch_emb(self): - """Compute global patch embedding dimension.""" dim_token_emb = self.encoder_dim_token_emb if self.cross_attn_encoder: cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 @@ -614,7 +609,6 @@ def global_dim_patch_emb(self): or not self.downsampling_by_pooling or len(self.downsampling_by_pooling) == 0 ): - # Use default patch_size of 8 if not set patch_size = self.patch_size if self.patch_size is not None else 8 return dim_token_emb * patch_size else: @@ -622,7 +616,6 @@ def global_dim_patch_emb(self): @property def decoder_dim_token_emb(self): - """Compute decoder token embedding dimension.""" if self.share_encoder_decoder_emb: return self.encoder_dim_token_emb elif self.dim_token is not None: diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index f8454d551d2c..0a6eaf1408dc 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -1,10 +1,10 @@ +#blt old + # Copyright (c) Meta Platforms, Inc. and affiliates. import logging import os -from typing import Callable, List, Optional, Tuple, Union - -from ...cache_utils import Cache +from typing import List, Optional, Tuple, Union import torch import torch.nn @@ -12,10 +12,8 @@ from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update - -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from .configuration_blt import ( +from ...modeling_utils import PreTrainedModel +from .configuration_blt_og import ( BLTConfig, PatchingModeEnum, ) @@ -26,6 +24,38 @@ flex_attention_comp = flex_attention + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") + + def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), @@ -37,15 +67,242 @@ def cross_entropy(pred, target, **kwargs): def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." - batch_size, slen, n_kv_heads, head_dim = x.shape + bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] - .expand(batch_size, slen, n_kv_heads, n_rep, head_dim) - .reshape(batch_size, slen, n_kv_heads * n_rep, head_dim) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + + def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class BLTSelfAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError(f"Attention implementation {attn_impl} not supported") + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + class BLTMLP(nn.Module): def __init__( self, @@ -89,170 +346,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.w2(F.silu(x1) * x3) return output -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] - sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] - - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - - return q_rot.type_as(q), k_rot.type_as(k) - - - -class BLTSelfAttention(nn.Module): - def __init__(self, config: BLTConfig, layer_idx: int): - super().__init__() - self.config = config - self.num_heads = config.num_attention_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 - self.rope_theta = config.rope_theta - self.layer_idx = layer_idx - - self.wq = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.wq(hidden_states) - key_states = self.wk(hidden_states) - value_states = self.wv(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - output_attentions = False - self.config._attn_implementation = "sdpa" - self.scaling = None - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.wo(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value class BLTTransformerLayer(nn.Module): - def __init__(self, dim, n_heads, config, layer_idx=0): + def __init__(self, args): super().__init__() # Extract parameters from dictionary - dim = dim - n_heads = n_heads - head_dim = config.head_dim - n_kv_heads = config.n_kv_heads - rope_theta = config.rope_theta - multiple_of = config.multiple_of - ffn_dim_multiplier = config.ffn_dim_multiplier - norm_eps = config.norm_eps - + dim = args["dim"] + n_heads = args["n_heads"] + head_dim = args["head_dim"] + n_kv_heads = args["n_kv_heads"] + rope_theta = args["rope_theta"] + multiple_of = args["multiple_of"] + ffn_dim_multiplier = args["ffn_dim_multiplier"] + norm_eps = args["norm_eps"] + + assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" self.head_dim = head_dim or dim // n_heads self.n_heads = n_heads or dim // head_dim self.n_kv_heads = n_kv_heads or self.n_heads - config.hidden_size = dim - - self.attention = BLTSelfAttention(config=config, layer_idx=layer_idx) + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 + self.attention = BLTSelfAttention( + dim=dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=rope_theta, + ) self.feed_forward = BLTMLP( dim=dim, hidden_dim=4 * dim, @@ -264,36 +389,24 @@ def __init__(self, dim, n_heads, config, layer_idx=0): def forward( self, - hidden_states: torch.Tensor, - past_key_value: Optional[bool] = None, - position_embeddings: Optional[torch.Tensor] = None, - - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", ) -> torch.Tensor: - - residual = hidden_states - norm_hidden_states = self.attention_norm(hidden_states) - - - hidden_states, self_attn_weights, present_key_value = self.attention( - hidden_states=norm_hidden_states, - # TODO: = BLT, attn_out = self.attention(self.attention_norm(x), in TransformerBlock.forward, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, ) - - hidden_states = residual + hidden_states - normalized_hidden_states = self.ffn_norm(hidden_states) - hidden_states = hidden_states + self.feed_forward(normalized_hidden_states) - return hidden_states + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 @@ -307,7 +420,7 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): +def rolling_polynomial_hash(t, hash_func_nb: int = 0): primes = [ 1000000007, 5915587277, @@ -321,25 +434,25 @@ def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): 5463458053, 3367900313, ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) - return torch.sum(token_tensor * prime_powers, dim=-1) + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) -def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): +def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): """ - Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. - expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. Note: max hash can make a big difference on the number of collisions. """ with torch.no_grad(): - batch_size, seq_len = token_ids.shape - prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - token_ids = torch.cat([prefix, token_ids], dim=1) - windows = token_ids.unfold(1, group_size, 1) + bs, seq_len = x.shape + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values_range = hashes % max_hash @@ -349,32 +462,32 @@ def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_ def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): """ - Creates a tensor of shape [batch_size, seq_len, num_patches] where each element at position (i, j, k) + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) is True if the patch id at position (i, j) is less than or equal to k. Args: - patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. num_patches (int): Total number of patches. window (int): If not None, only considers patches within a window of size window. patches_as_queries (bool): If True, the patches are used as queries Returns: - torch.Tensor: Tensor of shape [batch_size, q_len, kv_len] with the desired mask. + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. """ - batch_size, seq_len = patch_ids.shape + bs, seq_len = patch_ids.shape if not patches_as_queries: - q_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) kv_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(0) - .expand(batch_size, seq_len, num_patches) + .expand(bs, seq_len, num_patches) ) else: - kv_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) q_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(-1) - .expand(batch_size, num_patches, seq_len) + .expand(bs, num_patches, seq_len) ) if window is None: mask = q_ids == kv_ids @@ -392,7 +505,7 @@ def cross_attn_mask( window=None, block_mask=True, ): - batch_size = patch_ids.shape[0] + bs = patch_ids.shape[0] with torch.no_grad(): # Create the patch mask cross_mask = create_patch_mask_from_ids( @@ -404,19 +517,19 @@ def cross_attn_mask( q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k assert cross_mask.shape == ( - batch_size, + bs, q_len, kv_len, - ), f"{cross_mask.shape} != {(batch_size, q_len, kv_len)}" + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" block_mask = None if block_mask: - def patch_mask(b, num_heads, q_idx, kv_idx): + def patch_mask(b, h, q_idx, kv_idx): return cross_mask[b, q_idx, kv_idx] block_mask = create_block_mask( patch_mask, - B=batch_size, + B=bs, H=None, Q_LEN=q_len, KV_LEN=kv_len, @@ -426,7 +539,7 @@ def patch_mask(b, num_heads, q_idx, kv_idx): else: return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( 1 - ) # [batch_size, 1, q_len, kv_len] + ) # [bs, 1, q_len, kv_len] def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: @@ -459,88 +572,156 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> return padded - -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config: BLTConfig, device=None): +class BLTLocalModelBase(nn.Module): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): super().__init__() - self.rope_type = config.rope_scaling["rope_type"] - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_token_id + self.boe_id = config.boe_id + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None -class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTConfig): - super().__init__() - - # Extract config values to instance attributes - self.dropout = config.dropout - self.dim_local_encoder = config.dim_local_encoder - self.n_layers_local_encoder = config.n_layers_local_encoder - self.n_heads_local_encoder = config.n_heads_local_encoder - self.vocab_size = config.vocab_size - self.pm_size = config.pm_size - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_nheads = config.cross_attn_nheads - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_k = config.cross_attn_k - self.norm_eps = config.norm_eps - self.sliding_window = config.sliding_window - - self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_encoder, self.n_heads_local_encoder, config) for _ in range(self.n_layers_local_encoder)]) + # Create parameter dict for BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": self.n_heads, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } + + self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length + else: + self.rope = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None - # Set up config for rotary embedding - encoder_config = config - encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder - encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen - self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb + elif component_type == "decoder": + self.dim_token_emb = config.decoder_dim_token_emb + self.dim_patch_emb = config.dim_global self.token_embedding_projection = ( - nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) - if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim else None ) self.patch_embedding_projection = self._create_patch_projection(config) - self.tok_embeddings = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) + def _should_create_patch_projection(self, config: BLTConfig): + dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim - # Initialize cross attention layers only if cross attention is enabled - self.cross_attn_layers = None - if self.cross_attn_encoder and self.cross_attn_nheads is not None: + # Check cross attention conditions + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): + return None + + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=self.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + +class BLTLocalEncoder(BLTLocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") + + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) + + if self.cross_attn_encoder: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 - for layer_idx in range(layers_to_add): + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) ) + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." + return embeds + else: + return self.tok_embeddings(tokens) + def forward( self, - input_ids: torch.Tensor, - input_embeds: Optional[torch.Tensor] = None, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, @@ -549,72 +730,48 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - batch_size, sequence_length = input_ids.shape - if input_embeds is None: - input_embeds = self.embed_tokens(input_ids) - - batch_size, _, _ = input_embeds.shape - - hidden_states = input_embeds - - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + h = self.apply_embedding(tokens, embeds) + + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + - for idx, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + h = F.dropout(h, p=self.dropout, training=self.training) - if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): - # Initialize patch_embeds if not provided when cross attention is enabled - if patch_embeds is None: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + for i, layer in enumerate(self.layers): + h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) - - layer_idx = idx if self.cross_attn_all_layers_encoder else 0 - cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( - hidden_states=patch_embeds, - cross_attention_states=hidden_states, - attention_mask=cross_mask, - output_attentions=False, - use_cache=False, - cache_position=None, - ) - patch_embeds = patch_embeds + cross_attention_output - - encoder_cross_states = patch_embeds if self.cross_attn_encoder else None - return (hidden_states, encoder_cross_states), cache - - def _create_patch_projection(self, config): - dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder - - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) - if not (dimension_mismatch or cross_attn_conditions): - return None + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds = patch_embeds + patch_embeds_cross - output_dim = config.encoder_dim_token_emb * config.cross_attn_k + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache - return nn.Linear( - in_features=config.encoder_dim_patch_emb, - out_features=output_dim, - bias=False, - ) - - def embed_tokens(self, tokens, embeds=None): - if embeds is not None: - assert self.config.encoder_hash_byte_group_size is not None, "Not expecting embeddings to be passed." - return embeds - else: - return self.tok_embeddings(tokens) - - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch @@ -625,95 +782,55 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, seq_len, embedding_dim = hidden_states.shape + bs, seq_len, emb_dim = h.shape - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) - reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) - reduced_embeddings = reduced_embeddings.scatter_reduce( - src=hidden_states, + reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) + reduced_embs = reduced_embs.scatter_reduce( + src=h, dim=1, index=patch_ids, reduce=reduction, include_self=False, ) - reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + reduced_embs = reduced_embs[:, :max_num_patches, :] - return reduced_embeddings + return reduced_embs -class BLTLocalDecoder(nn.Module): +class BLTLocalDecoder(BLTLocalModelBase): def __init__(self, config: BLTConfig): - super().__init__() + super().__init__(config, component_type="decoder") - # Extract config values to instance attributes - self.dim_local_decoder = config.dim_local_decoder - self.n_heads_local_decoder = config.n_heads_local_decoder - self.n_layers_local_decoder = config.n_layers_local_decoder - self.vocab_size = config.vocab_size - self.norm_eps = config.norm_eps - self.dropout = config.dropout + # Model configuration flags self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_nheads = config.cross_attn_nheads self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder - self.cross_attn_k = config.cross_attn_k - self.sliding_window = config.sliding_window - - self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_decoder, self.n_heads_local_decoder, config) for _ in range(self.n_layers_local_decoder)]) - - decoder_config = config - decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder - decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen - - self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) - - self.token_embedding_projection = ( - nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) - if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder - else None - ) - - self.patch_embedding_projection = self._create_patch_projection(config) + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads - self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps) + self.norm = RMSNorm(self.dim, eps=config.norm_eps) - # Initialize cross attention layers only if cross attention is enabled - self.cross_attn_layers = None - if self.cross_attn_decoder and self.cross_attn_nheads is not None: + if self.cross_attn_decoder: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 - for layer_idx in range(layers_to_add): + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) ) - self.output = nn.Linear(self.dim_local_decoder, self.vocab_size, bias=False) - - def _create_patch_projection(self, config): - dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder - - # Check cross attention conditions - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - if not (dimension_mismatch or cross_attn_conditions): - return None - - output_dim = config.decoder_dim_token_emb * config.cross_attn_k - - return nn.Linear( - in_features=config.dim_global, - out_features=output_dim, + self.output = nn.Linear( + self.dim, + config.vocab_size, bias=False, ) - def apply_embedding(self, tokens, embeds): - if embeds is not None: - return embeds - else: - return self.tok_embeddings(tokens) - def forward( self, tokens: torch.Tensor, @@ -723,201 +840,214 @@ def forward( cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, sequence_length = tokens.shape - batch_size, seq_length, _ = embeds.shape - + bs, seqlen = tokens.shape assert embeds is not None, "Embeddings must be provided" - hidden_states = embeds + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) if patch_embeds is not None and not self.cross_attn_decoder: - hidden_states = hidden_states + patch_embeds + h = h + patch_embeds - position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): - # Use cross attention to extract info from patch_embeds into hidden_states - cross_attention_output, _, _ = self.cross_attn_layers[i]( - hidden_states=hidden_states, - cross_attention_states=patch_embeds, - attention_mask=cross_mask, - output_attentions=False, - use_cache=False, - cache_position=None, + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, ) - hidden_states = hidden_states + cross_attention_output + h = h + h_cross - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) - logits = self.norm(hidden_states) - logits = F.dropout(logits, p=self.dropout, training=self.training) - logits = self.output(logits) - logits = logits.float() - return logits, cache + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache class BLTCrossAttention(nn.Module): - """Cross-attention module for BLT, following transformers style""" - - def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): super().__init__() - self.config = config - self.layer_idx = layer_idx - # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.dim_local_encoder - self.num_heads = config.cross_attn_nheads - self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None #self.head_dim ** -0.5 - self.dropout = config.dropout - - self.wq = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.cross_attn_norm_q = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - self.cross_attn_norm_kv = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.dim = dim + self.head_dim = head_dim - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - - query_states = self.cross_attn_norm_q(hidden_states) # BLT normalizes first - query_states = self.wq(query_states) - - if cross_attention_states is not None: - cross_attention_states = self.cross_attn_norm_kv(cross_attention_states) # BLT normalizes first - key_states = self.wk(cross_attention_states) - value_states = self.wv(cross_attention_states) - if past_key_value is not None: - # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states - # we still update the cross key states, past_cross_states, new_cross_states. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position is not None and cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - if cross_attention_states is None: - raise ValueError( - "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" - ) + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads - attention_interface: Callable = eager_attention_forward + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0, #if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.wo(attn_output) + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) - attn_output = attn_output + hidden_states #TODO: they add the residual twice?? move this out + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + # output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype).to(xq.device) + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - if not output_attentions: - attn_weights = None + output = self.wo(output.reshape(output_shape)) - return attn_output, attn_weights, past_key_value + return x + output class BLTGlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - # Extract config values to instance attributes - self.dim_global = config.dim_global - self.n_heads_global = config.n_heads_global - self.n_layers_global = config.n_layers_global - self.dropout = config.dropout - - self.layers = nn.ModuleList() - old = config.n_kv_heads - config.n_kv_heads = config.n_kv_heads_global - for _ in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) - config.n_kv_heads = old + self.config = config - global_config = config - global_config.head_dim = self.dim_global // self.n_heads_global + self.dim = config.dim_global + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) + + # Create parameter dict for BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": config.n_heads_global, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads_global", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } - self.rotary_emb = BLTRotaryEmbedding(config=global_config) + self.layers = nn.ModuleList() + for _ in range(config.n_layers_global): + self.layers.append(BLTTransformerLayer(layer_params)) self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: self.token_embedding_projection = nn.Linear( config.global_dim_patch_emb, - self.dim_global, + config.dim_global, bias=False, ) def forward( self, - input_ids: torch.Tensor, + tokens: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, - input_embeds: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, seq_length, _ = input_embeds.shape - - hidden_states = input_embeds - - if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: - hidden_states = self.token_embedding_projection(hidden_states) + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.config.attn_impl, + self.config.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + h = F.dropout(h, p=self.config.dropout, training=self.training) + freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl) - return hidden_states, cache + return h, cache def compute_hash_embeddings( @@ -998,33 +1128,39 @@ def _init_weights(self, module): a=-3 * std, b=3 * std, ) - + + elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): + nn.init.ones_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, RotaryEmbedding): + module.freqs_cis[...] = precompute_freqs_cis( + dim=module.head_dim, + end=module.max_seqlen, + theta=module.theta, + rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, + ) + elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.dim_local_encoder ** (-0.5) + emb_std = module.local_encoder.dim ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - elif isinstance(module, BLTLocalEncoder): + elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) + module.token_embedding_projection._custom_std = module.dim ** (-0.5) if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) - - elif isinstance(module, BLTLocalDecoder): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) - - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) + module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) elif isinstance(module, BLTGlobalTransformer): if module.token_embedding_projection is not None: module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) elif isinstance(module, BLTPatcher): - emb_std = module.config.dim ** (-0.5) + emb_std = module.config.patcher_dim ** (-0.5) module.tok_embeddings._custom_std = emb_std module.output._custom_std = emb_std @@ -1033,34 +1169,18 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - # Extract frequently used config values - self.patch_in_forward = config.patch_in_forward - self.patching_mode = config.patching_mode - self.patch_size = config.patch_size - self.patching_threshold = config.patching_threshold - self.max_patch_length = config.max_patch_length - self.patching_batch_size = config.patching_batch_size - self.patching_device = config.patching_device - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_k = config.cross_attn_k - self.cross_attn_window_encoder = config.cross_attn_window_encoder - self.cross_attn_window_decoder = config.cross_attn_window_decoder - self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention - self.boe_id = config.boe_id - self.eos_token_id = config.eos_token_id - + self.config = config self.local_encoder = BLTLocalEncoder(config) self.global_transformer = BLTGlobalTransformer(config) self.local_decoder = BLTLocalDecoder(config) self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=config.dim_local_encoder, + local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if self.patch_in_forward: + if config.patch_in_forward: self.patcher = BLTPatcher(config) self.patcher.eval() for param in self.patcher.parameters(): @@ -1076,36 +1196,36 @@ def forward( # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings # are no longer used in the final BLT model - batch_size, sequence_length = tokens.shape # Batch size and sequence length + bs, N = tokens.shape # Batch size and sequence length local_encoder_tokens, local_decoder_tokens = tokens, tokens # Patching if patch_lengths is None: # assert ( - # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward # ), "Patch in forward not enabled and no patch_lengths passed." # PATCHER MODEL DEFINED - if self.patching_mode == PatchingModeEnum.entropy: + if self.config.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( local_encoder_tokens, - patch_size=self.patch_size, + patch_size=self.config.patch_size, include_next_token=True, - threshold=self.patching_threshold, - max_patch_length=self.max_patch_length, - patching_batch_size=self.patching_batch_size, - device=self.patching_device, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, ) else: - # self.patching_mode == PatchingModeEnum.byte - batch_size_tokens, seq_len = local_encoder_tokens.shape + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape seq_len_next_tok = seq_len + 1 # include_next_token=True patch_lengths = torch.ones( - (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device ) - patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) + patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) #assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths @@ -1116,15 +1236,15 @@ def forward( cross_attn_mask_enc = None # Cross-attention encoder - if self.cross_attn_encoder: + if self.config.cross_attn_encoder: cross_attn_mask_enc = cross_attn_mask( patch_ids, patch_lengths, - sequence_length, + N, patches_as_queries=True, - cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_encoder, - block_mask=self.cross_attn_use_flex_attention, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, ) # Hashing and embedding @@ -1141,9 +1261,9 @@ def forward( # The final BLT model uses only hash-based n-gram embeddings # Local encoder - (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( - input_ids=local_encoder_tokens, - input_embeds=local_encoder_embeds, + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], @@ -1151,58 +1271,50 @@ def forward( ) # Downsampling - if encoder_cross_states is not None: - # Cross attention is enabled - use cross states - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - else: - # Cross attention is disabled - use reduced embeddings from encoder hidden states - global_hidden_states = self.local_encoder.patch_reduce( - encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids - ) + h = h_cross.view(bs, patch_lengths.shape[1], -1) # Global transformer - global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.eos_token_id + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id - global_hidden_states, _ = self.global_transformer( - input_embeds=global_hidden_states, - input_ids=global_tokens, + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, ) # Unpatching - decoder_embeds = encoder_hidden_states + dec_embeds = h_encoder # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) - # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" - # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( - # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" + # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" # ) # Cross-attention decoder - if not self.cross_attn_decoder: - patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) + if not self.config.cross_attn_decoder: + h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) cross_attn_mask_dec = None - # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] + # assert local_decoder_tokens.shape == h.shape[:-1] else: - patch_hidden_states = global_hidden_states cross_attn_mask_dec = cross_attn_mask( decoder_patch_ids, patch_lengths, - sequence_length, + N, patches_as_queries=False, - cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_decoder, - block_mask=self.cross_attn_use_flex_attention, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, ) # Local decoder output, _ = self.local_decoder( - embeds=decoder_embeds, - patch_embeds=patch_hidden_states, + embeds=dec_embeds, + patch_embeds=h, tokens=local_decoder_tokens, cross_mask=cross_attn_mask_dec, ) @@ -1253,28 +1365,41 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): - # Store reference to main config for accessing non-patcher settings - self.main_config = config - - # Initialize with patcher config directly - super().__init__(config.patcher_config) + super().__init__(config) - # Initialize rotary embeddings with patcher config - self.rotary_emb = BLTRotaryEmbedding(config=self.config) + self.rope_embeddings = RotaryEmbedding( + theta=config.patcher_rope_theta, + head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, + max_seqlen=config.patcher_max_seqlen, + rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, + ) self.layers = nn.ModuleList() - for _ in range(self.config.n_layers): - self.layers.append(BLTTransformerLayer(self.config.dim, self.config.n_heads, self.config)) + for _ in range(config.patcher_n_layers): + self.layers.append( + BLTTransformerLayer( + { + "dim": config.patcher_dim, + "n_heads": config.patcher_n_heads, + "head_dim": config.patcher_head_dim, + "n_kv_heads": config.patcher_n_kv_heads, + "rope_theta": config.patcher_rope_theta, + "multiple_of": config.patcher_multiple_of, + "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, + "norm_eps": config.patcher_norm_eps, + } + ) + ) - #assert self.config.vocab_size > 0 + #assert config.patcher_vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim) + self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) - self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps) + self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) self.output = nn.Linear( - self.config.dim, - self.config.vocab_size, + config.patcher_dim, + config.patcher_vocab_size, bias=False, ) @@ -1291,8 +1416,8 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] - predictions = [] - max_length = self.config.max_seqlen + preds = [] + max_length = self.config.patcher_max_seqlen batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) @@ -1305,33 +1430,34 @@ def forward( split = split.to(device) # Process chunk: embeddings -> layers -> output - batch_size, sequence_length = split.shape - input_embeds = self.tok_embeddings(split) - - hidden_states = input_embeds - + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = create_causal_mask( + seqlen, + self.config.patcher_attn_impl , + self.config.patcher_attn_bias_type, + sliding_window=self.config.patcher_sliding_window, + tokens=split, + eos_id=self.config.eos_id, + ) - batch_size, seq_length, _ = input_embeds.shape + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) - position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - - position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl) - logits = self.output(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] - predictions.append(logits) - prediction_entropies = self.entropy(logits) - entropies.append(prediction_entropies) + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) + concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) # Always compute patch lengths from concatenated entropies - batch_size, sequence_length = token_values.shape - seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len # Find patch start IDs based on entropy if patch_size is not None: @@ -1344,17 +1470,17 @@ def forward( patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching - patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_predictions + return concat_entropies, patch_lengths, concat_preds @staticmethod def entropy(scores): """ - scores: [batch_size, seq_len, vocab] - returns [batch_size, seq_len] + scores: [bs, seq_len, vocab] + returns [bs, seq_len] Computes the entropy for each token in the batch. Note: uses natural log. @@ -1367,26 +1493,26 @@ def entropy(scores): @staticmethod def patch_start_ids_from_patch_start_mask(patch_start_mask): - batch_size, trunc_seq_len = patch_start_mask.shape + bs, trunc_seq_len = patch_start_mask.shape max_patches = patch_start_mask.sum(dim=1).max() if max_patches == 0: patch_start_ids = torch.full( - (batch_size, trunc_seq_len), + (bs, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) else: - patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) extra_patch_ids = torch.full( - (batch_size, trunc_seq_len), + (bs, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] return patch_start_ids @staticmethod @@ -1422,13 +1548,13 @@ def find_entropy_patch_start_ids( different sequences, but patches can be identified incrementally rather than decided globally using the entire sequence. """ - batch_size, sequence_length = entropies.shape[:2] + bs, seq_len = entropies.shape[:2] - first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) + preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. entropies = entropies[:, 1:] if threshold is None: - num_patches = sequence_length // patch_size + num_patches = seq_len // patch_size patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices patch_start_ids = patch_start_ids.sort(dim=1).values else: @@ -1438,7 +1564,7 @@ def find_entropy_patch_start_ids( # patch_start_mask[1:] |= tokens[:-1] < OFFSET patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) + patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) return patch_start_ids def init_hash_embeddings( @@ -1473,4 +1599,4 @@ def init_hash_embeddings( "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer", -] \ No newline at end of file +] From e6c7b68d280aafc8d4ca3d8f89b15b178a078456 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 20 Jun 2025 12:30:46 +0000 Subject: [PATCH 027/139] updated conversion script --- src/convert_blt_to_hf.py | 151 +++++++++++++-------------------------- 1 file changed, 50 insertions(+), 101 deletions(-) diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index 5a4d368b6b7e..54b242d0f7c6 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -17,40 +17,18 @@ transformers_logging.set_verbosity_info() -def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: - config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) - - weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) - - entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) - - entropy_weights_path = hf_hub_download( - repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir - ) - - return { - "config": config_path, - "weights": weights_path, - "entropy_params": entropy_params_path, - "entropy_weights": entropy_weights_path, - } - - def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: - logger.info("Merging confi") + logger.info("Merging configurations") - # Load BLT configuration with open(config_path, "r") as f: main_config = json.load(f) - # Load Patcher entropy model parameters with open(entropy_params_path, "r") as f: entropy_data = json.load(f) entropy_model_params = entropy_data.get("entropy_model", {}) patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) - # Create unified configuration unified_config = main_config.copy()["args"] for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: @@ -61,7 +39,6 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str if isinstance(patch_size, float): patch_size = int(patch_size) - # Create patcher configuration dictionary patcher_config = { "vocab_size": int(entropy_model_params.get("vocab_size", 256)), "dim": int(entropy_model_params.get("dim", 512)), @@ -69,10 +46,10 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str "n_heads": int(entropy_model_params.get("n_heads", 8)), "head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None - else None, # Let BLTPatcherConfig compute this from dim // n_heads + else None, "n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None - else None, # Let BLTPatcherConfig default this to n_heads + else None, "max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), "norm_eps": entropy_model_params.get("norm_eps", 1e-5), "dropout": entropy_model_params.get("dropout", 0.0), @@ -115,11 +92,41 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str return unified_config -def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: - logger.info("Merging model weights") +def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + component_mappings = { + ".attention.": ".self_attn.", + ".feed_forward.": ".mlp.", + ".attention_norm.": ".input_layernorm.", + ".ffn_norm.": ".post_attention_layernorm.", + ".tok_embeddings.": ".embed_tokens.", + ".cross_attn_norm_q.": ".q_norm.", + ".cross_attn_norm_kv.": ".k_norm.", + ".w1.": ".gate_proj.", + ".w2.": ".down_proj.", + ".w3.": ".up_proj.", + ".wq.": ".q_proj.", + ".wk.": ".k_proj.", + ".wv.": ".v_proj.", + ".wo.": ".o_proj.", + ".output.": ".lm_head.", + } + + new_state_dict = {} + + for old_key, tensor in state_dict.items(): + new_key = old_key + + for old_pattern, new_pattern in component_mappings.items(): + if old_pattern in new_key: + new_key = new_key.replace(old_pattern, new_pattern) + + new_state_dict[new_key] = tensor + + return new_state_dict + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: main_weights = load_file(weights_path) - logger.info(f"Loaded main model weights: {len(main_weights)} tensors") entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) @@ -130,21 +137,18 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") - # unified state dict unified_weights = main_weights.copy() - # Add entropy model weights with "patcher." prefix for key, tensor in entropy_weights.items(): patcher_key = f"patcher.{key}" unified_weights[patcher_key] = tensor - - logger.info(f"Merged weights: {len(unified_weights)} tensors total") + + unified_weights = apply_weight_mapping(unified_weights) + return unified_weights def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): - logger.info("Creating tokenizer config") - tokenizer_config = { "tokenizer_class": "BltTokenizer", "vocab_size": config.get("vocab_size", 256), @@ -164,54 +168,6 @@ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): logger.info(f"Tokenizer config saved to {tokenizer_path}") -def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): - logger.info("Validating unified model") - - required_keys = [ - "vocab_size", - "dim", - "n_layers", - "n_heads", - "patch_in_forward", - "patcher_args", - ] - - missing_keys = [key for key in required_keys if key not in config] - if missing_keys: - logger.warning(f"Missing configuration keys: {missing_keys}") - - # Check for patcher weights - patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] - if not patcher_weights: - logger.warning("No patcher weights found in unified weights") - else: - logger.info(f"Found {len(patcher_weights)} patcher weight tensors") - - main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] - logger.info(f"Found {len(main_weights)} main model weight tensors") - - try: - logger.info("Testing model instantiation...") - blt_config = BLTConfig(**config) - model = BLTModel(blt_config) - - logger.info("Testing weight loading...") - try: - missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) - if missing_keys: - logger.warning(f"Missing keys during weight loading: {missing_keys}") - if unexpected_keys: - logger.warning(f"Unexpected keys during weight loading: {unexpected_keys}") - logger.info("Weight loading successful") - except Exception as weight_error: - logger.warning(f"Weight loading failed: {weight_error}") - - except Exception as e: - logger.error(f"Model validation failed: {e}") - - logger.info("Model validation completed") - - def push_to_hub( local_dir: str, repo_id: str, @@ -220,7 +176,6 @@ def push_to_hub( token: Optional[str] = None, ) -> None: try: - # Upload the entire directory to the Hub upload_folder( folder_path=local_dir, repo_id=repo_id, @@ -246,17 +201,16 @@ def convert_hf_blt_to_unified( hub_private: bool = False, hub_token: Optional[str] = None, ) -> None: - logger.info(f"Converting {model_id} to unified transformers format") - - file_paths = download_model_files(model_id, cache_dir) - - # Merge configurations - unified_config = merge_configurations(file_paths["config"], file_paths["entropy_params"]) - - # Merge weights - unified_weights = merge_weights(file_paths["weights"], file_paths["entropy_weights"]) + # Download model files + config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) + weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) + entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) + entropy_weights_path = hf_hub_download( + repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir + ) - validate_unified_model(unified_config, unified_weights) + unified_config = merge_configurations(config_path, entropy_params_path) + unified_weights = merge_weights(weights_path, entropy_weights_path) os.makedirs(output_dir, exist_ok=True) @@ -275,8 +229,6 @@ def convert_hf_blt_to_unified( else: torch.save(unified_weights, weights_path) - logger.info(f"Unified config and weights saved to {weights_path}") - create_tokenizer_config(output_dir, unified_config) logger.info(f"Conversion completed, model saved to: {output_dir}") @@ -305,10 +257,8 @@ def main(): parser.add_argument( "--output_dir", type=str, - default="./new_unified_blt_debug", + default="./blt_converted", ) - - # Optional parser.add_argument( "--config_name", type=str, @@ -339,11 +289,10 @@ def main(): action="store_true", default=True, ) - parser.add_argument( "--push_to_hub", type=str, - default="itazap/blt-1b-hf", + default="itazap/blt-1b-converted", ) parser.add_argument( "--hub_private", From e3fdebba18b39148c914ba3b8b5a3ea76b4d4cc1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 08:29:06 +0000 Subject: [PATCH 028/139] overwritten commit! fixing PR --- .../models/blt_wip/modeling_blt.py | 1469 ++++++++--------- 1 file changed, 662 insertions(+), 807 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 0a6eaf1408dc..21d3951265aa 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -1,10 +1,10 @@ -#blt old - # Copyright (c) Meta Platforms, Inc. and affiliates. import logging import os -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union + +from ...cache_utils import Cache import torch import torch.nn @@ -12,8 +12,10 @@ from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -from ...modeling_utils import PreTrainedModel -from .configuration_blt_og import ( +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from .configuration_blt import ( BLTConfig, PatchingModeEnum, ) @@ -24,38 +26,6 @@ flex_attention_comp = flex_attention - -def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - -def create_causal_mask( - seqlen, - attn_impl: str, - attn_bias_type: str | None, - *, - eos_id: int | None = None, - tokens: torch.Tensor | None = None, - sliding_window: int | None = None, -): - if attn_impl == "sdpa": - BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) - - if attn_bias_type == "causal": - return "causal" - - if BLT_SUPPRESS_ATTN_ERROR == 1: - return "causal" - else: - raise ValueError( - "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" - ) - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") - - def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), @@ -67,242 +37,15 @@ def cross_entropy(pred, target, **kwargs): def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." - bs, slen, n_kv_heads, head_dim = x.shape + batch_size, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + .expand(batch_size, slen, n_kv_heads, n_rep, head_dim) + .reshape(batch_size, slen, n_kv_heads * n_rep, head_dim) ) - -def precompute_freqs_cis( - dim: int, - end: int, - theta: float = 10000.0, - rope_use_fp32_in_outer_product: bool = False, -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - if rope_use_fp32_in_outer_product: - t = t.to(torch.float32) - - freqs = torch.outer(t, freqs).float() - - cos, sin = freqs.cos(), freqs.sin() - - return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - seq_dim (int): Sequence dimension index. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert 0 <= seq_dim < ndim - assert freqs_cis.shape == ( - x.shape[seq_dim], - x.shape[-3], - 2, - 2, - ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" - shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - seq_dim: int, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 - xq_out = (xq_ * freqs_cis).sum(5).flatten(3) - xk_out = (xk_ * freqs_cis).sum(5).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. -class RotaryEmbedding(torch.nn.Module): - """ - RotaryEmbedding Module - """ - - def __init__( - self, - theta: float, - head_dim: int, - max_seqlen: int = 1024, - rope_use_fp32_in_outer_product: bool = False, - ): - super().__init__() - - self.theta = theta - self.head_dim = head_dim - self.max_seqlen = max_seqlen - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - dim=head_dim, - end=max_seqlen, - theta=theta, - rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ), - persistent=False, - ) - - - def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): - """ - Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions - Args: - seqlen (int): Contiguous sequence length - tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen - - Returns: - Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis - """ - test = (seqlen is not None) or (tok_idx is not None) - assert test, "Should provide atleast seqlen or tok_idx" - if tok_idx is not None: - return self.freqs_cis[tok_idx] - elif seqlen is not None: - return self.freqs_cis[0:seqlen] - - -class BLTSelfAttention(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - rope_theta: float, - ): - super().__init__() - - self.dim = dim - self.head_dim = head_dim - self.rope_theta = rope_theta - - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads - - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, - ) - - def forward( - self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", - ) -> torch.Tensor: - # B S D - bsz, seq_len, dim = x.shape - - xq = self.wq(x.view_as(x)) - xk = self.wk(x.view_as(x)) - xv = self.wv(x.view_as(x)) - - output_shape = xq.shape - # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) - - # This condition helps us be easily compatible - # with inference by adding a pluggable KVCache - if hasattr(self, "kv_cache"): - xk, xv = self.kv_cache.update(xk, xv, tok_idx) - - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) - - if attn_impl == "flex_attention": - assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - output = flex_attention_comp(xq, xk, xv, block_mask=mask) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - - elif attn_impl == "sdpa": - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - assert mask is None or isinstance(mask, (str, torch.Tensor)) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None - output = F.scaled_dot_product_attention( - xq, - xk, - xv, - is_causal=is_causal, - attn_mask=mask, - ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - else: - raise NotImplementedError(f"Attention implementation {attn_impl} not supported") - - output_reshaped = output.reshape(output_shape) - - output = self.wo(output_reshaped) - - return output - - class BLTMLP(nn.Module): def __init__( self, @@ -323,17 +66,17 @@ def __init__( self.dim = dim self.hidden_dim = hidden_dim - self.w1 = nn.Linear( + self.gate_proj = nn.Linear( dim, hidden_dim, bias=False, ) - self.w3 = nn.Linear( + self.up_proj = nn.Linear( dim, hidden_dim, bias=False, ) - self.w2 = nn.Linear( + self.down_proj = nn.Linear( hidden_dim, dim, bias=False, @@ -341,72 +84,216 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: # B S D - x1 = self.w1(x.view_as(x)) - x3 = self.w3(x.view_as(x)) - output = self.w2(F.silu(x1) * x3) + x1 = self.gate_proj(x.view_as(x)) + x3 = self.up_proj(x.view_as(x)) + output = self.down_proj(F.silu(x1) * x3) return output +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + + +class BLTSelfAttention(nn.Module): + def __init__(self, config: BLTConfig, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim ** -0.5 + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + self.config._attn_implementation = "sdpa" + self.scaling = None + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class BLTTransformerLayer(nn.Module): - def __init__(self, args): + def __init__(self, dim, n_heads, config, layer_idx=0): super().__init__() # Extract parameters from dictionary - dim = args["dim"] - n_heads = args["n_heads"] - head_dim = args["head_dim"] - n_kv_heads = args["n_kv_heads"] - rope_theta = args["rope_theta"] - multiple_of = args["multiple_of"] - ffn_dim_multiplier = args["ffn_dim_multiplier"] - norm_eps = args["norm_eps"] - - assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" + dim = dim + n_heads = n_heads + head_dim = config.head_dim + n_kv_heads = config.n_kv_heads + rope_theta = config.rope_theta + multiple_of = config.multiple_of + ffn_dim_multiplier = config.ffn_dim_multiplier + norm_eps = config.norm_eps + self.head_dim = head_dim or dim // n_heads self.n_heads = n_heads or dim // head_dim self.n_kv_heads = n_kv_heads or self.n_heads - assert n_heads % self.n_kv_heads == 0 - assert dim % n_heads == 0 + config.hidden_size = dim - self.attention = BLTSelfAttention( - dim=dim, - head_dim=self.head_dim, - n_heads=self.n_heads, - n_kv_heads=self.n_kv_heads, - rope_theta=rope_theta, - ) - self.feed_forward = BLTMLP( + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + + self.mlp = BLTMLP( dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, ) - self.attention_norm = RMSNorm(dim, eps=norm_eps) - self.ffn_norm = RMSNorm(dim, eps=norm_eps) + self.input_layernorm = RMSNorm(dim, eps=norm_eps) + self.post_attention_layernorm = RMSNorm(dim, eps=norm_eps) def forward( self, - x: torch.Tensor, - freq_cis: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - mask: Optional[Union[BlockMask, str]] = None, - attn_impl: str = "sdpa", + hidden_states: torch.Tensor, + past_key_value: Optional[bool] = None, + position_embeddings: Optional[torch.Tensor] = None, + + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: - norm_x = self.attention_norm(x) - attn_out = self.attention( - norm_x, - freq_cis, - tok_idx=tok_idx, - mask=mask, - attn_impl=attn_impl, + + residual = hidden_states + norm_hidden_states = self.input_layernorm(hidden_states) + + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=norm_hidden_states, + # TODO: = BLT, attn_out = self.self_attn(self.input_layernorm(x), in TransformerBlock.forward, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings ) - h = x + attn_out - h_norm = self.ffn_norm(h) - out = h + self.feed_forward(h_norm) - return out + + hidden_states = residual + hidden_states + normalized_hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + self.mlp(normalized_hidden_states) + return hidden_states def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 @@ -420,7 +307,7 @@ def check_non_zero_after_zero(tensor): non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() -def rolling_polynomial_hash(t, hash_func_nb: int = 0): +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ 1000000007, 5915587277, @@ -434,25 +321,25 @@ def rolling_polynomial_hash(t, hash_func_nb: int = 0): 5463458053, 3367900313, ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) - prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) - return torch.sum(t * prime_powers, dim=-1) + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) + return torch.sum(token_tensor * prime_powers, dim=-1) -def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): """ - Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. - expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. Note: max hash can make a big difference on the number of collisions. """ with torch.no_grad(): - bs, seq_len = x.shape - prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) - x = torch.cat([prefix, x], dim=1) - windows = x.unfold(1, group_size, 1) + batch_size, seq_len = token_ids.shape + prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + token_ids = torch.cat([prefix, token_ids], dim=1) + windows = token_ids.unfold(1, group_size, 1) # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values_range = hashes % max_hash @@ -462,32 +349,32 @@ def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): """ - Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + Creates a tensor of shape [batch_size, seq_len, num_patches] where each element at position (i, j, k) is True if the patch id at position (i, j) is less than or equal to k. Args: - patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. num_patches (int): Total number of patches. window (int): If not None, only considers patches within a window of size window. patches_as_queries (bool): If True, the patches are used as queries Returns: - torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + torch.Tensor: Tensor of shape [batch_size, q_len, kv_len] with the desired mask. """ - bs, seq_len = patch_ids.shape + batch_size, seq_len = patch_ids.shape if not patches_as_queries: - q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + q_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) kv_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(0) - .expand(bs, seq_len, num_patches) + .expand(batch_size, seq_len, num_patches) ) else: - kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + kv_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) q_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(-1) - .expand(bs, num_patches, seq_len) + .expand(batch_size, num_patches, seq_len) ) if window is None: mask = q_ids == kv_ids @@ -505,7 +392,7 @@ def cross_attn_mask( window=None, block_mask=True, ): - bs = patch_ids.shape[0] + batch_size = patch_ids.shape[0] with torch.no_grad(): # Create the patch mask cross_mask = create_patch_mask_from_ids( @@ -517,19 +404,19 @@ def cross_attn_mask( q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k assert cross_mask.shape == ( - bs, + batch_size, q_len, kv_len, - ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + ), f"{cross_mask.shape} != {(batch_size, q_len, kv_len)}" block_mask = None if block_mask: - def patch_mask(b, h, q_idx, kv_idx): + def patch_mask(b, num_heads, q_idx, kv_idx): return cross_mask[b, q_idx, kv_idx] block_mask = create_block_mask( patch_mask, - B=bs, + B=batch_size, H=None, Q_LEN=q_len, KV_LEN=kv_len, @@ -539,7 +426,7 @@ def patch_mask(b, h, q_idx, kv_idx): else: return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( 1 - ) # [bs, 1, q_len, kv_len] + ) # [batch_size, 1, q_len, kv_len] def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: @@ -572,156 +459,88 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> return padded -class BLTLocalModelBase(nn.Module): - def __init__(self, config: BLTConfig, component_type: str = "encoder"): + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config: BLTConfig, device=None): super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - if component_type == "encoder": - self.dim = config.dim_local_encoder - self.n_layers = config.n_layers_local_encoder - self.n_heads = config.n_heads_local_encoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - elif component_type == "decoder": - self.dim = config.dim_local_decoder - self.n_layers = config.n_layers_local_decoder - self.n_heads = config.n_heads_local_decoder - self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen - self.attn_bias_type = "local_block_causal" - self.sliding_window = config.local_attention_window_len - else: - raise ValueError(f"Unknown component_type: {component_type}") + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() - self.dropout = config.dropout - self.vocab_size = config.vocab_size + config.pm_size - self.patch_size = config.patch_size + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling - self.attn_impl = config.attn_impl - self.use_rope = config.use_rope - self.init_std_factor = config.init_std_factor - self.init_base_std = config.init_base_std - self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) - self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) - self.cross_attn_k = getattr(config, "cross_attn_k", None) - self.eos_id = config.eos_token_id + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - self.boe_id = config.boe_id - # Initialize cross attention layers as None (will be set by subclasses if needed) - self.cross_attn_layers = None - # Create parameter dict for BLTTransformerLayers - layer_params = { - "dim": self.dim, - "n_heads": self.n_heads, - "head_dim": config.head_dim, - "n_kv_heads": getattr(config, "n_kv_heads", None), - "rope_theta": config.rope_theta, - "multiple_of": getattr(config, "multiple_of", 256), - "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), - "norm_eps": config.norm_eps, - } - - self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) - - if not self.use_rope: - self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length - else: - self.rope = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or self.dim // self.n_heads, - max_seqlen=self.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - self.pos_embeddings = None +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTConfig): + super().__init__() + + # Extract config values to instance attributes + self.dropout = config.dropout + self.dim_local_encoder = config.dim_local_encoder + self.n_layers_local_encoder = config.n_layers_local_encoder + self.n_heads_local_encoder = config.n_heads_local_encoder + self.vocab_size = config.vocab_size + self.pm_size = config.pm_size + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_k = config.cross_attn_k + self.norm_eps = config.norm_eps + self.sliding_window = config.sliding_window + + self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_encoder, self.n_heads_local_encoder, config) for _ in range(self.n_layers_local_encoder)]) - # Set dimension-specific embedding dimensions - if component_type == "encoder": - self.dim_token_emb = config.encoder_dim_token_emb - self.dim_patch_emb = config.encoder_dim_patch_emb - elif component_type == "decoder": - self.dim_token_emb = config.decoder_dim_token_emb - self.dim_patch_emb = config.dim_global + # Set up config for rotary embedding + encoder_config = config + encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder + encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen + self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) self.token_embedding_projection = ( - nn.Linear(self.dim_token_emb, self.dim, bias=False) - if self.dim_token_emb is not None and self.dim_token_emb != self.dim + nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) + if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder else None ) self.patch_embedding_projection = self._create_patch_projection(config) - def _should_create_patch_projection(self, config: BLTConfig): - dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim + self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) - # Check cross attention conditions - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - return dimension_mismatch or cross_attn_conditions - - def _create_patch_projection(self, config): - if not self._should_create_patch_projection(config): - return None - - output_dim = self.dim_token_emb * (self.cross_attn_k or 1) - - return nn.Linear( - in_features=self.dim_patch_emb, - out_features=output_dim, - bias=False, - ) - - def apply_embedding(self, tokens, embeds): - if embeds is not None: - return embeds - else: - return self.tok_embeddings(tokens) - - -class BLTLocalEncoder(BLTLocalModelBase): - def __init__(self, config: BLTConfig): - super().__init__(config, component_type="encoder") - - self.apply_transformer = config.use_local_encoder_transformer - self.downsampling_by_pooling = config.downsampling_by_pooling - self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling - self.cross_attn_nheads = config.cross_attn_nheads - - self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) - - if self.cross_attn_encoder: + # Initialize cross attention layers only if cross attention is enabled + self.cross_attn_layers = None + if self.cross_attn_encoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 - for _ in range(layers_to_add): + layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 + for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) ) - def apply_embedding(self, tokens, embeds): - if embeds is not None: - assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." - return embeds - else: - return self.tok_embeddings(tokens) - def forward( self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, @@ -730,48 +549,65 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - bs, seqlen = tokens.shape - if mask is None: - mask = create_causal_mask( - seqlen, - self.attn_impl, - "local_block_causal", - sliding_window=self.sliding_window, - tokens=tokens, - eos_id=self.eos_id, - ) + batch_size, sequence_length = input_ids.shape + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) - h = self.apply_embedding(tokens, embeds) - - - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None - + batch_size, _, _ = input_embeds.shape - h = F.dropout(h, p=self.dropout, training=self.training) + hidden_states = input_embeds - for i, layer in enumerate(self.layers): - h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) - # check if cross attention should be applied to either all layer or only the last layer - if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): - # apply pooling and project - if self.cross_attn_init_by_pooling and patch_embeds is None: - patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + + if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # Initialize patch_embeds if not provided when cross attention is enabled + if patch_embeds is None: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) - - layer_idx = i if self.cross_attn_all_layers_encoder else 0 - patch_embeds_cross = self.cross_attn_layers[layer_idx]( - x=patch_embeds, - kv=h, - mask=cross_mask, + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) + + layer_idx = idx if self.cross_attn_all_layers_encoder else 0 + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + output_attentions=False, + use_cache=False, + cache_position=None, ) - patch_embeds = patch_embeds + patch_embeds_cross + patch_embeds = patch_embeds + cross_attention_output - h_residual = patch_embeds if self.cross_attn_encoder else None - return (h, h_residual), cache + encoder_cross_states = patch_embeds if self.cross_attn_encoder else None + return (hidden_states, encoder_cross_states), cache + + def _create_patch_projection(self, config): + dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder - def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.encoder_dim_token_emb * config.cross_attn_k + + return nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch @@ -782,52 +618,90 @@ def patch_reduce(self, h, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - bs, seq_len, emb_dim = h.shape + batch_size, seq_len, embedding_dim = hidden_states.shape - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) - reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) - reduced_embs = reduced_embs.scatter_reduce( - src=h, + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, dim=1, index=patch_ids, reduce=reduction, include_self=False, ) - reduced_embs = reduced_embs[:, :max_num_patches, :] + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] - return reduced_embs + return reduced_embeddings -class BLTLocalDecoder(BLTLocalModelBase): +class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTConfig): - super().__init__(config, component_type="decoder") + super().__init__() - # Model configuration flags + # Extract config values to instance attributes + self.dim_local_decoder = config.dim_local_decoder + self.n_heads_local_decoder = config.n_heads_local_decoder + self.n_layers_local_decoder = config.n_layers_local_decoder + self.vocab_size = config.vocab_size + self.norm_eps = config.norm_eps + self.dropout = config.dropout self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_k = config.cross_attn_k + self.sliding_window = config.sliding_window + + self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_decoder, self.n_heads_local_decoder, config) for _ in range(self.n_layers_local_decoder)]) + + decoder_config = config + decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder + decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen - self.norm = RMSNorm(self.dim, eps=config.norm_eps) + self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) - if self.cross_attn_decoder: + self.token_embedding_projection = ( + nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) + if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps) + + # Initialize cross attention layers only if cross attention is enabled + self.cross_attn_layers = None + if self.cross_attn_decoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 - for _ in range(layers_to_add): + layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 + for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention( - dim=self.dim, - head_dim=self.dim // self.cross_attn_nheads, - n_heads=self.cross_attn_nheads, - n_kv_heads=self.cross_attn_nheads, - norm_eps=config.norm_eps, - ) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) ) - self.output = nn.Linear( - self.dim, - config.vocab_size, + self.lm_head = nn.Linear( + self.dim_local_decoder, + self.vocab_size, + bias=False, + ) + + def _create_patch_projection(self, config): + dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder + + # Check cross attention conditions + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.decoder_dim_token_emb * config.cross_attn_k + + return nn.Linear( + in_features=config.dim_global, + out_features=output_dim, bias=False, ) @@ -840,214 +714,198 @@ def forward( cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - bs, seqlen = tokens.shape - assert embeds is not None, "Embeddings must be provided" + batch_size, sequence_length = tokens.shape + batch_size, seq_length, _ = embeds.shape - if mask is None: - mask = create_causal_mask( - seqlen, - self.attn_impl, - "local_block_causal", - sliding_window=self.sliding_window, - tokens=tokens, - eos_id=self.eos_id, - ) + assert embeds is not None, "Embeddings must be provided" - h = embeds + hidden_states = embeds if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) if patch_embeds is not None and not self.cross_attn_decoder: - h = h + patch_embeds + hidden_states = hidden_states + patch_embeds - freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) - h = F.dropout(h, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): - # Use cross attention to extract info from patch_embeds into h - h_cross = self.cross_attn_layers[i]( - x=h, - kv=patch_embeds, - mask=cross_mask, + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output, _, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + output_attentions=False, + use_cache=False, + cache_position=None, ) - h = h + h_cross + hidden_states = hidden_states + cross_attention_output - h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - h_preds = self.norm(h) - h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) - h_preds = self.output(h_preds) - h_preds = h_preds.float() - return h_preds, cache + logits = self.lm_head(self.norm(hidden_states)) + return logits, cache class BLTCrossAttention(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - ): + """Cross-attention module for BLT, following transformers style""" + + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.dim_local_encoder + self.num_heads = config.cross_attn_nheads + self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None #self.head_dim ** -0.5 + self.dropout = config.dropout - self.dim = dim - self.head_dim = head_dim + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.heads_per_group = self.n_heads // self.n_kv_heads + self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) - self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" + ) - self.wq = nn.Linear( - dim, - n_heads * head_dim, - bias=False, - ) - self.wk = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) - self.wv = nn.Linear( - dim, - n_kv_heads * head_dim, - bias=False, - ) + attention_interface: Callable = eager_attention_forward - self.wo = nn.Linear( - n_heads * head_dim, - dim, - bias=False, + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, #if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, ) - def forward( - self, - x: torch.Tensor, - kv: torch.Tensor, - mask: Optional[Union[BlockMask, str]] = None, - ) -> torch.Tensor: - # B S D - bsz, seq_len, _ = x.shape - _, slen_kv, _ = kv.shape - x_norm = self.cross_attn_norm_q(x) - kv = self.cross_attn_norm_kv(kv) - - xq = self.wq(x_norm) - xk = self.wk(kv) - xv = self.wv(kv) - - output_shape = xq.shape - # B S D -> B S H D - xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) - xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) - - xk = repeat_kv(xk, self.heads_per_group, dim=2) - xv = repeat_kv(xv, self.heads_per_group, dim=2) - - # assert mask is None or isinstance(mask, BlockMask) - xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) - # output = flex_attention_comp(xq, xk, xv, block_mask=mask) - is_causal = (mask == "causal") if isinstance(mask, str) else False - mask = mask if isinstance(mask, torch.Tensor) else None - mask = mask.to(dtype=xq.dtype).to(xq.device) - output = F.scaled_dot_product_attention( - xq, - xk, - xv, - is_causal=is_causal, - attn_mask=mask, - ) - output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) - output = self.wo(output.reshape(output_shape)) + attn_output = attn_output + hidden_states #TODO: they add the residual twice?? move this out - return x + output + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class BLTGlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - self.config = config - - self.dim = config.dim_global - self.rope_embeddings = RotaryEmbedding( - theta=config.rope_theta, - head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, - max_seqlen=config.max_seqlen, - rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, - ) - # Handle both eos_id and eos_token_id for compatibility - self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) - - # Create parameter dict for BLTTransformerLayers - layer_params = { - "dim": self.dim, - "n_heads": config.n_heads_global, - "head_dim": config.head_dim, - "n_kv_heads": getattr(config, "n_kv_heads_global", None), - "rope_theta": config.rope_theta, - "multiple_of": getattr(config, "multiple_of", 256), - "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), - "norm_eps": config.norm_eps, - } + # Extract config values to instance attributes + self.dim_global = config.dim_global + self.n_heads_global = config.n_heads_global + self.n_layers_global = config.n_layers_global + self.dropout = config.dropout self.layers = nn.ModuleList() - for _ in range(config.n_layers_global): - self.layers.append(BLTTransformerLayer(layer_params)) + old = config.n_kv_heads + config.n_kv_heads = config.n_kv_heads_global + for _ in range(self.n_layers_global): + self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) + config.n_kv_heads = old + + global_config = config + global_config.head_dim = self.dim_global // self.n_heads_global + + self.rotary_emb = BLTRotaryEmbedding(config=global_config) self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: self.token_embedding_projection = nn.Linear( config.global_dim_patch_emb, - config.dim_global, + self.dim_global, bias=False, ) def forward( self, - tokens: torch.Tensor, + input_ids: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, - embeds: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - bs, seqlen = tokens.shape - - h = embeds - - mask = ( - mask - if mask is not None - else create_causal_mask( - seqlen, - self.config.attn_impl, - self.config.attn_bias_type, - tokens=tokens, - eos_id=self.eos_id, - ) - ) + batch_size, seq_length, _ = input_embeds.shape - if self.token_embedding_projection is not None and h.shape[-1] != self.dim: - h = self.token_embedding_projection(h) + hidden_states = input_embeds - h = F.dropout(h, p=self.config.dropout, training=self.training) - freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) + if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: + hidden_states = self.token_embedding_projection(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - return h, cache + return hidden_states, cache def compute_hash_embeddings( @@ -1063,7 +921,7 @@ def compute_hash_embeddings( Args: local_encoder_tokens: Input tokens tensor - local_encoder: Encoder object with tok_embeddings method + local_encoder: Encoder object with embed_tokens method encoder_hash_tok_embedding: ModuleList of hash token embeddings encoder_hash_byte_group_nb_functions: Number of hash functions encoder_hash_byte_group_size: List of byte group sizes @@ -1075,7 +933,7 @@ def compute_hash_embeddings( if encoder_hash_tok_embedding is None: return None - local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens) i = 0 for func_nb in range(encoder_hash_byte_group_nb_functions): @@ -1128,59 +986,69 @@ def _init_weights(self, module): a=-3 * std, b=3 * std, ) - - elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): - nn.init.ones_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - - elif isinstance(module, RotaryEmbedding): - module.freqs_cis[...] = precompute_freqs_cis( - dim=module.head_dim, - end=module.max_seqlen, - theta=module.theta, - rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, - ) - + elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.local_encoder.dim ** (-0.5) + emb_std = module.config.dim_local_encoder ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): + elif isinstance(module, BLTLocalEncoder): if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.dim ** (-0.5) + module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) elif isinstance(module, BLTGlobalTransformer): if module.token_embedding_projection is not None: module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) elif isinstance(module, BLTPatcher): - emb_std = module.config.patcher_dim ** (-0.5) - module.tok_embeddings._custom_std = emb_std - module.output._custom_std = emb_std + emb_std = module.config.dim ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - self.config = config + # Extract frequently used config values + self.patch_in_forward = config.patch_in_forward + self.patching_mode = config.patching_mode + self.patch_size = config.patch_size + self.patching_threshold = config.patching_threshold + self.max_patch_length = config.max_patch_length + self.patching_batch_size = config.patching_batch_size + self.patching_device = config.patching_device + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_k = config.cross_attn_k + self.cross_attn_window_encoder = config.cross_attn_window_encoder + self.cross_attn_window_decoder = config.cross_attn_window_decoder + self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention + self.boe_id = config.boe_id + self.eos_token_id = config.eos_token_id + self.local_encoder = BLTLocalEncoder(config) self.global_transformer = BLTGlobalTransformer(config) self.local_decoder = BLTLocalDecoder(config) self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=self.local_encoder.dim, + local_encoder_dim=config.dim_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if config.patch_in_forward: + if self.patch_in_forward: self.patcher = BLTPatcher(config) self.patcher.eval() for param in self.patcher.parameters(): @@ -1196,36 +1064,36 @@ def forward( # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings # are no longer used in the final BLT model - bs, N = tokens.shape # Batch size and sequence length + batch_size, sequence_length = tokens.shape # Batch size and sequence length local_encoder_tokens, local_decoder_tokens = tokens, tokens # Patching if patch_lengths is None: # assert ( - # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward # ), "Patch in forward not enabled and no patch_lengths passed." # PATCHER MODEL DEFINED - if self.config.patching_mode == PatchingModeEnum.entropy: + if self.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( local_encoder_tokens, - patch_size=self.config.patch_size, + patch_size=self.patch_size, include_next_token=True, - threshold=self.config.patching_threshold, - max_patch_length=self.config.max_patch_length, - patching_batch_size=self.config.patching_batch_size, - device=self.config.patching_device, + threshold=self.patching_threshold, + max_patch_length=self.max_patch_length, + patching_batch_size=self.patching_batch_size, + device=self.patching_device, ) else: - # self.config.patching_mode == PatchingModeEnum.byte - bs, seq_len = local_encoder_tokens.shape + # self.patching_mode == PatchingModeEnum.byte + batch_size_tokens, seq_len = local_encoder_tokens.shape seq_len_next_tok = seq_len + 1 # include_next_token=True patch_lengths = torch.ones( - (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device ) - patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) + patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) #assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths @@ -1236,15 +1104,15 @@ def forward( cross_attn_mask_enc = None # Cross-attention encoder - if self.config.cross_attn_encoder: + if self.cross_attn_encoder: cross_attn_mask_enc = cross_attn_mask( patch_ids, patch_lengths, - N, + sequence_length, patches_as_queries=True, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_encoder, - block_mask=self.config.cross_attn_use_flex_attention, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, ) # Hashing and embedding @@ -1261,9 +1129,9 @@ def forward( # The final BLT model uses only hash-based n-gram embeddings # Local encoder - (h_encoder, h_cross), cache_encoder = self.local_encoder( - tokens=local_encoder_tokens, - embeds=local_encoder_embeds, + (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( + input_ids=local_encoder_tokens, + input_embeds=local_encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], @@ -1271,50 +1139,58 @@ def forward( ) # Downsampling - h = h_cross.view(bs, patch_lengths.shape[1], -1) + if encoder_cross_states is not None: + # Cross attention is enabled - use cross states + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + else: + # Cross attention is disabled - use reduced embeddings from encoder hidden states + global_hidden_states = self.local_encoder.patch_reduce( + encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids + ) # Global transformer - global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + global_tokens[rows, eos_patch_ids] = self.eos_token_id - h, _ = self.global_transformer( - embeds=h, - tokens=global_tokens, + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, + input_ids=global_tokens, ) # Unpatching - dec_embeds = h_encoder + decoder_embeds = encoder_hidden_states # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) - # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" - # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( - # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" + # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" # ) # Cross-attention decoder - if not self.config.cross_attn_decoder: - h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) + if not self.cross_attn_decoder: + patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) cross_attn_mask_dec = None - # assert local_decoder_tokens.shape == h.shape[:-1] + # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] else: + patch_hidden_states = global_hidden_states cross_attn_mask_dec = cross_attn_mask( decoder_patch_ids, patch_lengths, - N, + sequence_length, patches_as_queries=False, - cross_attn_k=self.config.cross_attn_k, - window=self.config.cross_attn_window_decoder, - block_mask=self.config.cross_attn_use_flex_attention, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, ) # Local decoder output, _ = self.local_decoder( - embeds=dec_embeds, - patch_embeds=h, + embeds=decoder_embeds, + patch_embeds=patch_hidden_states, tokens=local_decoder_tokens, cross_mask=cross_attn_mask_dec, ) @@ -1365,41 +1241,22 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config): - super().__init__(config) + super().__init__(config.patcher_config) - self.rope_embeddings = RotaryEmbedding( - theta=config.patcher_rope_theta, - head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, - max_seqlen=config.patcher_max_seqlen, - rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, - ) + self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - for _ in range(config.patcher_n_layers): - self.layers.append( - BLTTransformerLayer( - { - "dim": config.patcher_dim, - "n_heads": config.patcher_n_heads, - "head_dim": config.patcher_head_dim, - "n_kv_heads": config.patcher_n_kv_heads, - "rope_theta": config.patcher_rope_theta, - "multiple_of": config.patcher_multiple_of, - "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, - "norm_eps": config.patcher_norm_eps, - } - ) - ) + for _ in range(self.config.n_layers): + self.layers.append(BLTTransformerLayer(self.config.dim, self.config.n_heads, self.config)) - #assert config.patcher_vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) - self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) + self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps) - self.output = nn.Linear( - config.patcher_dim, - config.patcher_vocab_size, + self.lm_head = nn.Linear( + self.config.dim, + self.config.vocab_size, bias=False, ) @@ -1416,8 +1273,8 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] - preds = [] - max_length = self.config.patcher_max_seqlen + predictions = [] + max_length = self.config.max_seqlen batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) @@ -1430,34 +1287,32 @@ def forward( split = split.to(device) # Process chunk: embeddings -> layers -> output - bsz, seqlen = split.shape - h = self.tok_embeddings(split) - chunk_mask = create_causal_mask( - seqlen, - self.config.patcher_attn_impl , - self.config.patcher_attn_bias_type, - sliding_window=self.config.patcher_sliding_window, - tokens=split, - eos_id=self.config.eos_id, - ) + batch_size, sequence_length = split.shape + input_embeds = self.embed_tokens(split) + + hidden_states = input_embeds - freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + batch_size, _, _ = input_embeds.shape + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + for i, layer in enumerate(self.layers): - h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl) + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) - pred = self.output(self.norm(h)) - pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] - preds.append(pred) - pred_entropies = self.entropy(pred) - entropies.append(pred_entropies) + logits = self.lm_head(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + predictions.append(logits) + prediction_entropies = self.entropy(logits) + entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) # Always compute patch lengths from concatenated entropies - bs, seq_len = token_values.shape - seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + batch_size, sequence_length = token_values.shape + seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length # Find patch start IDs based on entropy if patch_size is not None: @@ -1470,17 +1325,17 @@ def forward( patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching - patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_preds + return concat_entropies, patch_lengths, concat_predictions @staticmethod def entropy(scores): """ - scores: [bs, seq_len, vocab] - returns [bs, seq_len] + scores: [batch_size, seq_len, vocab] + returns [batch_size, seq_len] Computes the entropy for each token in the batch. Note: uses natural log. @@ -1493,26 +1348,26 @@ def entropy(scores): @staticmethod def patch_start_ids_from_patch_start_mask(patch_start_mask): - bs, trunc_seq_len = patch_start_mask.shape + batch_size, trunc_seq_len = patch_start_mask.shape max_patches = patch_start_mask.sum(dim=1).max() if max_patches == 0: patch_start_ids = torch.full( - (bs, trunc_seq_len), + (batch_size, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) else: - patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) extra_patch_ids = torch.full( - (bs, trunc_seq_len), + (batch_size, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=patch_start_mask.device, ) all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] return patch_start_ids @staticmethod @@ -1548,13 +1403,13 @@ def find_entropy_patch_start_ids( different sequences, but patches can be identified incrementally rather than decided globally using the entire sequence. """ - bs, seq_len = entropies.shape[:2] + batch_size, sequence_length = entropies.shape[:2] - first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) - preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. entropies = entropies[:, 1:] if threshold is None: - num_patches = seq_len // patch_size + num_patches = sequence_length // patch_size patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices patch_start_ids = patch_start_ids.sort(dim=1).values else: @@ -1564,7 +1419,7 @@ def find_entropy_patch_start_ids( # patch_start_mask[1:] |= tokens[:-1] < OFFSET patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) + patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) return patch_start_ids def init_hash_embeddings( @@ -1599,4 +1454,4 @@ def init_hash_embeddings( "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer", -] +] \ No newline at end of file From 438e2e26e594ee4555f82cc58cc455ea401eef76 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 13:09:38 +0000 Subject: [PATCH 029/139] apply feedback --- src/convert_blt_to_hf.py | 23 +- .../models/blt_wip/configuration_blt.py | 18 +- .../models/blt_wip/modeling_blt.py | 289 +++++++++--------- .../models/blt_wip/tokenization_blt.py | 3 +- 4 files changed, 163 insertions(+), 170 deletions(-) diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index 54b242d0f7c6..3b8eb79baba3 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file from transformers.models.blt_wip.configuration_blt import BLTConfig -from transformers.models.blt_wip.modeling_blt_modellike import BLTModel +from transformers.models.blt_wip.modeling_blt import BLTModel from transformers.utils import logging as transformers_logging @@ -195,7 +195,6 @@ def convert_hf_blt_to_unified( output_dir: str, config_name: str = "config.json", weights_name: str = "model.bin", - safe_serialization: bool = True, cache_dir: Optional[str] = None, push_to_hub_repo: Optional[str] = None, hub_private: bool = False, @@ -218,16 +217,11 @@ def convert_hf_blt_to_unified( with open(config_path, "w") as f: json.dump(unified_config, f, indent=2) - if safe_serialization and weights_name.endswith(".bin"): + if weights_name.endswith(".bin"): weights_name = weights_name.replace(".bin", ".safetensors") - elif not safe_serialization and weights_name.endswith(".safetensors"): - weights_name = weights_name.replace(".safetensors", ".bin") weights_path = os.path.join(output_dir, weights_name) - if safe_serialization: - save_file(unified_weights, weights_path) - else: - torch.save(unified_weights, weights_path) + save_file(unified_weights, weights_path) create_tokenizer_config(output_dir, unified_config) @@ -269,16 +263,6 @@ def main(): type=str, default="model.bin", ) - parser.add_argument( - "--safe_serialization", - action="store_true", - default=True, - ) - parser.add_argument( - "--no_safe_serialization", - dest="safe_serialization", - action="store_false", - ) parser.add_argument( "--cache_dir", type=str, @@ -316,7 +300,6 @@ def main(): output_dir=args.output_dir, config_name=args.config_name, weights_name=args.weights_name, - safe_serialization=args.safe_serialization, cache_dir=args.cache_dir, push_to_hub_repo=args.push_to_hub, hub_private=args.hub_private, diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index a10bee26c182..e2225b0cbfcf 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -1,7 +1,5 @@ -# new config - # coding=utf-8 -# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 Facebook Research and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BLT (Byte Latent Transformer) model configuration""" +"""BLT model configuration""" from enum import Enum @@ -150,6 +148,10 @@ def __init__( self.num_attention_heads = n_heads self.num_key_value_heads = self.n_kv_heads # Use the computed n_kv_heads self.max_position_embeddings = max_seqlen + self.hidden_act = "silu" # BLT uses silu activation + + # intermediate_size will be calculated in BLTMLP based on actual hidden_size + self.intermediate_size = None # Set simple rope scaling for patcher (no complex dynamic rope) self.rope_scaling = {"rope_type": "default"} @@ -380,6 +382,7 @@ def __init__( dropout=0.0, ffn_dim_multiplier=1.0, multiple_of=256, + hidden_act="silu", # Positional encoding rope_theta=10000.0, rope_use_fp32_in_outer_product=False, @@ -476,6 +479,7 @@ def __init__( self.dropout = dropout self.ffn_dim_multiplier = ffn_dim_multiplier self.multiple_of = multiple_of + self.hidden_act = hidden_act # Positional encoding self.rope_theta = rope_theta @@ -569,6 +573,10 @@ def __init__( self.max_position_embeddings=max_seqlen self.hidden_size=dim_local_encoder self.num_attention_heads=n_heads_local_encoder + + # Calculate intermediate_size using BLTMLP logic for each component + # Note: Each component uses its own hidden dimension, not the main dim + self.intermediate_size = None # Will be calculated per component super().__init__( bos_token_id=bos_token_id, @@ -623,6 +631,8 @@ def decoder_dim_token_emb(self): else: return self.dim_local_decoder + + def get_init_std_factor(self, depth: int) -> float: """ Calculate the initialization standard deviation scaling factor for a given layer depth. diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 21d3951265aa..e71de7599cf3 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -1,16 +1,29 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import logging -import os +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model.""" + +from ...utils import is_torch_flex_attn_available, logging from typing import Callable, List, Optional, Tuple, Union from ...cache_utils import Cache +from ...activations import ACT2FN import torch import torch.nn import torch.nn as nn from torch.nn import functional as F -from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update @@ -22,16 +35,14 @@ RMSNorm = nn.RMSNorm -logger = logging.getLogger() +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask -flex_attention_comp = flex_attention + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) -def cross_entropy(pred, target, **kwargs): - return F.nll_loss( - F.log_softmax(pred.flatten(end_dim=-2).float(), -1), - target.flatten(end_dim=-1), - **kwargs, - ) def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: @@ -47,47 +58,23 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: ) class BLTMLP(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - mp_size: int = 1, - ): + def __init__(self, config): super().__init__() - - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - assert hidden_dim % mp_size == 0 - - self.dim = dim - self.hidden_dim = hidden_dim - - self.gate_proj = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.up_proj = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.down_proj = nn.Linear( - hidden_dim, - dim, - bias=False, - ) + self.config = config + self.hidden_size = config.hidden_size + + # Calculate intermediate_size based on actual hidden_size (not config.dim) + base_dim = 4 * self.hidden_size + intermediate_dim = int(2 * base_dim / 3) + self.intermediate_size = config.multiple_of * ((intermediate_dim + config.multiple_of - 1) // config.multiple_of) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] def forward(self, x: torch.Tensor) -> torch.Tensor: - # B S D - x1 = self.gate_proj(x.view_as(x)) - x3 = self.up_proj(x.view_as(x)) - output = self.down_proj(F.silu(x1) * x3) - return output + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) def eager_attention_forward( module: nn.Module, @@ -149,7 +136,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) - +# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT class BLTSelfAttention(nn.Module): def __init__(self, config: BLTConfig, layer_idx: int): super().__init__() @@ -253,12 +240,7 @@ def __init__(self, dim, n_heads, config, layer_idx=0): self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP( - dim=dim, - hidden_dim=4 * dim, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - ) + self.mlp = BLTMLP(config=config) self.input_layernorm = RMSNorm(dim, eps=norm_eps) self.post_attention_layernorm = RMSNorm(dim, eps=norm_eps) @@ -347,86 +329,93 @@ def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_ return hash_values_range -def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: """ - Creates a tensor of shape [batch_size, seq_len, num_patches] where each element at position (i, j, k) - is True if the patch id at position (i, j) is less than or equal to k. + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + Args: patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. num_patches (int): Total number of patches. - window (int): If not None, only considers patches within a window of size window. - patches_as_queries (bool): If True, the patches are used as queries + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + Returns: - torch.Tensor: Tensor of shape [batch_size, q_len, kv_len] with the desired mask. + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ batch_size, seq_len = patch_ids.shape - if not patches_as_queries: - q_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) - kv_ids = ( - torch.arange(num_patches, device=patch_ids.device) - .unsqueeze(0) - .unsqueeze(0) - .expand(batch_size, seq_len, num_patches) + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( + batch_size, num_patches, seq_len ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) else: - kv_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) - q_ids = ( - torch.arange(num_patches, device=patch_ids.device) - .unsqueeze(0) - .unsqueeze(-1) - .expand(batch_size, num_patches, seq_len) + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( + batch_size, seq_len, num_patches ) - if window is None: - mask = q_ids == kv_ids - else: - mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) - return mask - - -def cross_attn_mask( - patch_ids, - patch_lengths, - N, - patches_as_queries=False, - cross_attn_k=1, - window=None, - block_mask=True, -): - batch_size = patch_ids.shape[0] - with torch.no_grad(): - # Create the patch mask - cross_mask = create_patch_mask_from_ids( - patch_ids, - patch_lengths.shape[1], - window=window, - patches_as_queries=patches_as_queries, - ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) - q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N - kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k - assert cross_mask.shape == ( - batch_size, - q_len, - kv_len, - ), f"{cross_mask.shape} != {(batch_size, q_len, kv_len)}" - block_mask = None - if block_mask: - - def patch_mask(b, num_heads, q_idx, kv_idx): - return cross_mask[b, q_idx, kv_idx] - - block_mask = create_block_mask( - patch_mask, - B=batch_size, - H=None, - Q_LEN=q_len, - KV_LEN=kv_len, - _compile=True, - ) - return block_mask - else: - return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( - 1 - ) # [batch_size, 1, q_len, kv_len] + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + + + + + def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: @@ -544,6 +533,7 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, @@ -580,6 +570,7 @@ def forward( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, output_attentions=False, use_cache=False, cache_position=None, @@ -712,6 +703,7 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): batch_size, sequence_length = tokens.shape @@ -741,6 +733,7 @@ def forward( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, output_attentions=False, use_cache=False, cache_position=None, @@ -783,6 +776,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -844,7 +838,11 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - attn_output = attn_output + hidden_states #TODO: they add the residual twice?? move this out + # Apply full row masking if provided (following mllama pattern) + if full_text_row_masked_out_mask is not None: + attn_output = full_text_row_masked_out_mask[:, 0] * attn_output + + attn_output = attn_output + hidden_states if not output_attentions: attn_weights = None @@ -1034,7 +1032,6 @@ def __init__(self, config: BLTConfig): self.cross_attn_k = config.cross_attn_k self.cross_attn_window_encoder = config.cross_attn_window_encoder self.cross_attn_window_decoder = config.cross_attn_window_decoder - self.cross_attn_use_flex_attention = config.cross_attn_use_flex_attention self.boe_id = config.boe_id self.eos_token_id = config.eos_token_id @@ -1102,17 +1099,17 @@ def forward( # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" # ) - cross_attn_mask_enc = None # Cross-attention encoder + cross_attn_mask_enc = None + full_text_row_masked_out_mask_enc = None if self.cross_attn_encoder: - cross_attn_mask_enc = cross_attn_mask( - patch_ids, - patch_lengths, - sequence_length, + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids=patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, patches_as_queries=True, cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_encoder, - block_mask=self.cross_attn_use_flex_attention, + dtype=torch.float32, ) # Hashing and embedding @@ -1134,6 +1131,7 @@ def forward( input_embeds=local_encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, ) @@ -1171,20 +1169,20 @@ def forward( # ) # Cross-attention decoder + cross_attn_mask_dec = None + full_text_row_masked_out_mask_dec = None if not self.cross_attn_decoder: patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) - cross_attn_mask_dec = None # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] else: patch_hidden_states = global_hidden_states - cross_attn_mask_dec = cross_attn_mask( - decoder_patch_ids, - patch_lengths, - sequence_length, + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + patch_ids=decoder_patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, patches_as_queries=False, cross_attn_k=self.cross_attn_k, - window=self.cross_attn_window_decoder, - block_mask=self.cross_attn_use_flex_attention, + dtype=torch.float32, ) # Local decoder @@ -1193,6 +1191,7 @@ def forward( patch_embeds=patch_hidden_states, tokens=local_decoder_tokens, cross_mask=cross_attn_mask_dec, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) return output @@ -1447,6 +1446,8 @@ def init_hash_embeddings( return nn.ModuleList(embeddings) + + __all__ = [ "BLTPreTrainedModel", "BLTModel", diff --git a/src/transformers/models/blt_wip/tokenization_blt.py b/src/transformers/models/blt_wip/tokenization_blt.py index cf57143de5dd..f5fba8a50625 100644 --- a/src/transformers/models/blt_wip/tokenization_blt.py +++ b/src/transformers/models/blt_wip/tokenization_blt.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tokenization classes for BLT.""" import os From 8ecda8429bb0be0b06c717663879970757e18f06 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 13:15:10 +0000 Subject: [PATCH 030/139] adding BLTRMSNorm like Llama --- src/demo_hf.py | 2 +- .../models/blt_wip/modeling_blt.py | 35 +++++++++++++------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 7140d9749e0b..f43bac62fb6f 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -3,7 +3,7 @@ import torch -from transformers.models.blt_wip.modeling_blt_modellike import BLTModel +from transformers.models.blt_wip.modeling_blt import BLTModel from transformers.models.blt_wip.tokenization_blt import BLTTokenizer diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index e71de7599cf3..f883b08feb20 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -135,6 +135,26 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT class BLTSelfAttention(nn.Module): @@ -241,8 +261,8 @@ def __init__(self, dim, n_heads, config, layer_idx=0): self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config=config) - self.input_layernorm = RMSNorm(dim, eps=norm_eps) - self.post_attention_layernorm = RMSNorm(dim, eps=norm_eps) + self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps) + self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps) def forward( self, @@ -411,13 +431,6 @@ def _prepare_patch_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask - - - - - - - def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: if max_patch_length is None: return patch_lengths @@ -659,7 +672,7 @@ def __init__(self, config: BLTConfig): self.patch_embedding_projection = self._create_patch_projection(config) - self.norm = RMSNorm(self.dim_local_decoder, eps=self.norm_eps) + self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps) # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None @@ -1251,7 +1264,7 @@ def __init__(self, config): self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) - self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps) + self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps) self.lm_head = nn.Linear( self.config.dim, From ceb3d8e23335cbaaea113e1f0e00ff89f509dd53 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 13:20:02 +0000 Subject: [PATCH 031/139] add repeat_kv and eager_attention_forward copied from --- .../models/blt_wip/modeling_blt.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index f883b08feb20..b34e34acc5b2 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -33,8 +33,6 @@ PatchingModeEnum, ) -RMSNorm = nn.RMSNorm - if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask @@ -44,18 +42,18 @@ logger = logging.get_logger(__name__) - -def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." - batch_size, slen, n_kv_heads, head_dim = x.shape +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(batch_size, slen, n_kv_heads, n_rep, head_dim) - .reshape(batch_size, slen, n_kv_heads * n_rep, head_dim) - ) + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + class BLTMLP(nn.Module): def __init__(self, config): @@ -76,6 +74,7 @@ def __init__(self, config): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +# Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, From 2102d32515d93ae5b6a3ac8232bd000f1ee4fe99 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 13:26:59 +0000 Subject: [PATCH 032/139] BLTMLP identical to MllamTextMLP --- .../models/blt_wip/modeling_blt.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index b34e34acc5b2..07a62359a469 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -54,25 +54,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - +# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP class BLTMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size - - # Calculate intermediate_size based on actual hidden_size (not config.dim) - base_dim = 4 * self.hidden_size - intermediate_dim = int(2 * base_dim / 3) - self.intermediate_size = config.multiple_of * ((intermediate_dim + config.multiple_of - 1) // config.multiple_of) - + self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( @@ -259,7 +255,14 @@ def __init__(self, dim, n_heads, config, layer_idx=0): self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config=config) + # Create a copy of config for MLP with pre-calculated dimensions + mlp_config = type(config)(**config.__dict__) + mlp_config.hidden_size = dim + + # Calculate intermediate_size using the same logic as BLTMLP + mlp_config.intermediate_size = multiple_of * (( int(8 * dim / 3) + multiple_of - 1) // multiple_of) + + self.mlp = BLTMLP(config=mlp_config) self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps) self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps) From 50d250368653adedf81328b1a4e4e55c866dea87 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 13:56:30 +0000 Subject: [PATCH 033/139] clean up some args' --- src/transformers/models/blt_wip/modeling_blt.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 07a62359a469..5a34622a5fe6 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -238,13 +238,9 @@ def __init__(self, dim, n_heads, config, layer_idx=0): super().__init__() # Extract parameters from dictionary - dim = dim - n_heads = n_heads head_dim = config.head_dim n_kv_heads = config.n_kv_heads - rope_theta = config.rope_theta multiple_of = config.multiple_of - ffn_dim_multiplier = config.ffn_dim_multiplier norm_eps = config.norm_eps self.head_dim = head_dim or dim // n_heads From 66bcddb904cf235852fff3aede5a02d0298b5b88 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 14:03:21 +0000 Subject: [PATCH 034/139] more like mllama, but busier inits --- .../models/blt_wip/modeling_blt.py | 165 +++++++++++------- 1 file changed, 105 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 5a34622a5fe6..8f16f52eac3d 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -233,67 +233,86 @@ def forward( return attn_output, attn_weights, past_key_value +# Copied from transformers.models.llama.modeling_mllama.MllamaSelfAttentionDecoderLayer class BLTTransformerLayer(nn.Module): - def __init__(self, dim, n_heads, config, layer_idx=0): + def __init__(self, config: BLTConfig, layer_idx: int): super().__init__() - - # Extract parameters from dictionary - head_dim = config.head_dim - n_kv_heads = config.n_kv_heads - multiple_of = config.multiple_of - norm_eps = config.norm_eps - - self.head_dim = head_dim or dim // n_heads - self.n_heads = n_heads or dim // head_dim - self.n_kv_heads = n_kv_heads or self.n_heads - - config.hidden_size = dim + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - - # Create a copy of config for MLP with pre-calculated dimensions - mlp_config = type(config)(**config.__dict__) - mlp_config.hidden_size = dim - - # Calculate intermediate_size using the same logic as BLTMLP - mlp_config.intermediate_size = multiple_of * (( int(8 * dim / 3) + multiple_of - 1) // multiple_of) - - self.mlp = BLTMLP(config=mlp_config) - self.input_layernorm = BLTRMSNorm(dim, eps=norm_eps) - self.post_attention_layernorm = BLTRMSNorm(dim, eps=norm_eps) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) def forward( self, hidden_states: torch.Tensor, - past_key_value: Optional[bool] = None, - position_embeddings: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - - ) -> torch.Tensor: - + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ residual = hidden_states - norm_hidden_states = self.input_layernorm(hidden_states) - + hidden_states = self.input_layernorm(hidden_states) + # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=norm_hidden_states, - # TODO: = BLT, attn_out = self.self_attn(self.input_layernorm(x), in TransformerBlock.forward, - past_key_value=past_key_value, + hidden_states=hidden_states, attention_mask=attention_mask, - layer_head_mask=layer_head_mask, + position_ids=position_ids, + past_key_value=past_key_value, output_attentions=output_attentions, + use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings + position_embeddings=position_embeddings, + **kwargs, ) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - normalized_hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = hidden_states + self.mlp(normalized_hidden_states) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 @@ -509,11 +528,17 @@ def __init__(self, config: BLTConfig): self.norm_eps = config.norm_eps self.sliding_window = config.sliding_window - self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_encoder, self.n_heads_local_encoder, config) for _ in range(self.n_layers_local_encoder)]) - - # Set up config for rotary embedding + # Set up config for layers with proper dimensions encoder_config = config + encoder_config.hidden_size = self.dim_local_encoder + encoder_config.num_attention_heads = self.n_heads_local_encoder + encoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_encoder encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder + encoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_encoder / 3) + config.multiple_of - 1) // config.multiple_of) + + self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + + # Set up config for rotary embedding encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) @@ -566,7 +591,8 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for idx, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): # Initialize patch_embeds if not provided when cross attention is enabled @@ -654,10 +680,16 @@ def __init__(self, config: BLTConfig): self.cross_attn_k = config.cross_attn_k self.sliding_window = config.sliding_window - self.layers = nn.ModuleList([BLTTransformerLayer(self.dim_local_decoder, self.n_heads_local_decoder, config) for _ in range(self.n_layers_local_decoder)]) - + # Set up config for layers with proper dimensions decoder_config = config + decoder_config.hidden_size = self.dim_local_decoder + decoder_config.num_attention_heads = self.n_heads_local_decoder + decoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_decoder decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder + decoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_decoder / 3) + config.multiple_of - 1) // config.multiple_of) + + self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) @@ -751,12 +783,13 @@ def forward( ) hidden_states = hidden_states + cross_attention_output - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) return logits, cache - +# Modified from transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention class BLTCrossAttention(nn.Module): """Cross-attention module for BLT, following transformers style""" @@ -871,15 +904,17 @@ def __init__(self, config): self.n_layers_global = config.n_layers_global self.dropout = config.dropout - self.layers = nn.ModuleList() - old = config.n_kv_heads - config.n_kv_heads = config.n_kv_heads_global - for _ in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(self.dim_global, self.n_heads_global, config)) - config.n_kv_heads = old - + # Set up config for layers with proper dimensions global_config = config + global_config.hidden_size = self.dim_global + global_config.num_attention_heads = self.n_heads_global + global_config.num_key_value_heads = getattr(config, 'n_kv_heads_global', None) or self.n_heads_global global_config.head_dim = self.dim_global // self.n_heads_global + global_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_global / 3) + config.multiple_of - 1) // config.multiple_of) + + self.layers = nn.ModuleList() + for layer_idx in range(self.n_layers_global): + self.layers.append(BLTTransformerLayer(global_config, layer_idx)) self.rotary_emb = BLTRotaryEmbedding(config=global_config) @@ -912,7 +947,8 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] return hidden_states, cache @@ -1256,8 +1292,16 @@ def __init__(self, config): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - for _ in range(self.config.n_layers): - self.layers.append(BLTTransformerLayer(self.config.dim, self.config.n_heads, self.config)) + # Set up config for layers with proper dimensions + patcher_config = self.config + patcher_config.hidden_size = self.config.dim + patcher_config.num_attention_heads = self.config.n_heads + patcher_config.num_key_value_heads = getattr(self.config, 'n_kv_heads', None) or self.config.n_heads + patcher_config.head_dim = self.config.dim // self.config.n_heads + patcher_config.intermediate_size = self.config.multiple_of * ((int(8 * self.config.dim / 3) + self.config.multiple_of - 1) // self.config.multiple_of) + + for layer_idx in range(self.config.n_layers): + self.layers.append(BLTTransformerLayer(patcher_config, layer_idx)) self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) @@ -1309,7 +1353,8 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] From 5bcc11d98d7e4cb8adddb2e8f56f0cade82e30c7 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 15:06:42 +0000 Subject: [PATCH 035/139] BLTTransformerLayer config --- .../models/blt_wip/configuration_blt.py | 113 +++++++++++++++++- .../models/blt_wip/modeling_blt.py | 66 +++------- 2 files changed, 125 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index e2225b0cbfcf..52b7496b90e5 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -37,6 +37,44 @@ class PatchingModeEnum(str, Enum): byte = "byte" +class TransformersLayerConfig: + """ + Configuration class for BLT Transformer layers, providing all necessary parameters + for attention, MLP, and normalization components. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + intermediate_size: int, + norm_eps: float, + dropout: float, + max_position_embeddings: int, + rope_theta: float, + rope_scaling: dict, + hidden_act: str = "silu", + **kwargs, + ): + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.hidden_act = hidden_act + + # Add any additional kwargs as attributes + for key, value in kwargs.items(): + setattr(self, key, value) + + class BLTPatcherConfig(PretrainedConfig): r""" Configuration class for the BLT Patcher/Entropy model component. @@ -578,6 +616,63 @@ def __init__( # Note: Each component uses its own hidden dimension, not the main dim self.intermediate_size = None # Will be calculated per component + # layer configurations as dictionaries (needed to be JSON serializable!) + self._encoder_layer_config_dict = { + "hidden_size": self.dim_local_encoder, + "num_attention_heads": self.n_heads_local_encoder, + "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_encoder, + "head_dim": self.dim_local_encoder // self.n_heads_local_encoder, + "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_encoder / 3) + self.multiple_of - 1) // self.multiple_of), + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + } + + self._decoder_layer_config_dict = { + "hidden_size": self.dim_local_decoder, + "num_attention_heads": self.n_heads_local_decoder, + "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_decoder, + "head_dim": self.dim_local_decoder // self.n_heads_local_decoder, + "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_decoder / 3) + self.multiple_of - 1) // self.multiple_of), + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + } + + self._global_layer_config_dict = { + "hidden_size": self.dim_global, + "num_attention_heads": self.n_heads_global, + "num_key_value_heads": getattr(self, 'n_kv_heads_global', None) or self.n_heads_global, + "head_dim": self.dim_global // self.n_heads_global, + "intermediate_size": self.multiple_of * ((int(8 * self.dim_global / 3) + self.multiple_of - 1) // self.multiple_of), + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "max_position_embeddings": self.max_seqlen, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + } + + self._patcher_layer_config_dict = { + "hidden_size": self.patcher_config.dim, + "num_attention_heads": self.patcher_config.n_heads, + "num_key_value_heads": getattr(self.patcher_config, 'n_kv_heads', None) or self.patcher_config.n_heads, + "head_dim": self.patcher_config.dim // self.patcher_config.n_heads, + "intermediate_size": self.patcher_config.multiple_of * ((int(8 * self.patcher_config.dim / 3) + self.patcher_config.multiple_of - 1) // self.patcher_config.multiple_of), + "norm_eps": self.patcher_config.norm_eps, + "dropout": self.patcher_config.dropout, + "max_position_embeddings": self.patcher_config.max_seqlen, + "rope_theta": self.patcher_config.rope_theta, + "rope_scaling": self.patcher_config.rope_scaling, + "hidden_act": self.hidden_act, + } + super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, @@ -585,6 +680,21 @@ def __init__( **kwargs, ) + @property + def encoder_layer_config(self) -> TransformersLayerConfig: + return TransformersLayerConfig(**self._encoder_layer_config_dict) + + @property + def decoder_layer_config(self) -> TransformersLayerConfig: + return TransformersLayerConfig(**self._decoder_layer_config_dict) + + @property + def global_layer_config(self) -> TransformersLayerConfig: + return TransformersLayerConfig(**self._global_layer_config_dict) + + @property + def patcher_layer_config(self) -> TransformersLayerConfig: + return TransformersLayerConfig(**self._patcher_layer_config_dict) @property def encoder_dim_token_emb(self): @@ -648,6 +758,5 @@ def get_init_std_factor(self, depth: int) -> float: else: # DISABLED return 1.0 - -__all__ = ["BLTConfig", "BLTPatcherConfig", "InitStdFactor", "PatchingModeEnum"] +__all__ = ["BLTConfig", "BLTPatcherConfig", "TransformersLayerConfig", "InitStdFactor", "PatchingModeEnum"] diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 8f16f52eac3d..1a44faee0648 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -31,6 +31,7 @@ from .configuration_blt import ( BLTConfig, PatchingModeEnum, + TransformersLayerConfig, ) if is_torch_flex_attn_available(): @@ -153,7 +154,7 @@ def extra_repr(self): # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT class BLTSelfAttention(nn.Module): - def __init__(self, config: BLTConfig, layer_idx: int): + def __init__(self, config: TransformersLayerConfig, layer_idx: int): super().__init__() self.config = config self.num_heads = config.num_attention_heads @@ -233,9 +234,10 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_mllama.MllamaSelfAttentionDecoderLayer + +# Copied from transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer class BLTTransformerLayer(nn.Module): - def __init__(self, config: BLTConfig, layer_idx: int): + def __init__(self, config: TransformersLayerConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.layer_idx = layer_idx @@ -480,7 +482,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> class BLTRotaryEmbedding(nn.Module): - def __init__(self, config: BLTConfig, device=None): + def __init__(self, config: TransformersLayerConfig, device=None): super().__init__() self.rope_type = config.rope_scaling["rope_type"] self.max_seq_len_cached = config.max_position_embeddings @@ -528,19 +530,9 @@ def __init__(self, config: BLTConfig): self.norm_eps = config.norm_eps self.sliding_window = config.sliding_window - # Set up config for layers with proper dimensions - encoder_config = config - encoder_config.hidden_size = self.dim_local_encoder - encoder_config.num_attention_heads = self.n_heads_local_encoder - encoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_encoder - encoder_config.head_dim = self.dim_local_encoder // self.n_heads_local_encoder - encoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_encoder / 3) + config.multiple_of - 1) // config.multiple_of) - - self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config.encoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) - # Set up config for rotary embedding - encoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen - self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config.encoder_layer_config) self.token_embedding_projection = ( nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) @@ -552,7 +544,6 @@ def __init__(self, config: BLTConfig): self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) - # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None if self.cross_attn_encoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() @@ -680,19 +671,9 @@ def __init__(self, config: BLTConfig): self.cross_attn_k = config.cross_attn_k self.sliding_window = config.sliding_window - # Set up config for layers with proper dimensions - decoder_config = config - decoder_config.hidden_size = self.dim_local_decoder - decoder_config.num_attention_heads = self.n_heads_local_decoder - decoder_config.num_key_value_heads = getattr(config, 'n_kv_heads', None) or self.n_heads_local_decoder - decoder_config.head_dim = self.dim_local_decoder // self.n_heads_local_decoder - decoder_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_local_decoder / 3) + config.multiple_of - 1) // config.multiple_of) - - self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config.decoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) - decoder_config.max_position_embeddings = config.max_encoder_seq_length or config.max_seqlen - - self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config.decoder_layer_config) self.token_embedding_projection = ( nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) @@ -704,7 +685,6 @@ def __init__(self, config: BLTConfig): self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps) - # Initialize cross attention layers only if cross attention is enabled self.cross_attn_layers = None if self.cross_attn_decoder and self.cross_attn_nheads is not None: self.cross_attn_layers = torch.nn.ModuleList() @@ -789,7 +769,7 @@ def forward( logits = self.lm_head(self.norm(hidden_states)) return logits, cache -# Modified from transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention + class BLTCrossAttention(nn.Module): """Cross-attention module for BLT, following transformers style""" @@ -898,25 +878,16 @@ class BLTGlobalTransformer(nn.Module): def __init__(self, config): super().__init__() - # Extract config values to instance attributes self.dim_global = config.dim_global self.n_heads_global = config.n_heads_global self.n_layers_global = config.n_layers_global self.dropout = config.dropout - # Set up config for layers with proper dimensions - global_config = config - global_config.hidden_size = self.dim_global - global_config.num_attention_heads = self.n_heads_global - global_config.num_key_value_heads = getattr(config, 'n_kv_heads_global', None) or self.n_heads_global - global_config.head_dim = self.dim_global // self.n_heads_global - global_config.intermediate_size = config.multiple_of * ((int(8 * self.dim_global / 3) + config.multiple_of - 1) // config.multiple_of) - self.layers = nn.ModuleList() for layer_idx in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(global_config, layer_idx)) + self.layers.append(BLTTransformerLayer(config.global_layer_config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=global_config) + self.rotary_emb = BLTRotaryEmbedding(config=config.global_layer_config) self.token_embedding_projection = None if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: @@ -1292,17 +1263,8 @@ def __init__(self, config): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - # Set up config for layers with proper dimensions - patcher_config = self.config - patcher_config.hidden_size = self.config.dim - patcher_config.num_attention_heads = self.config.n_heads - patcher_config.num_key_value_heads = getattr(self.config, 'n_kv_heads', None) or self.config.n_heads - patcher_config.head_dim = self.config.dim // self.config.n_heads - patcher_config.intermediate_size = self.config.multiple_of * ((int(8 * self.config.dim / 3) + self.config.multiple_of - 1) // self.config.multiple_of) - for layer_idx in range(self.config.n_layers): - self.layers.append(BLTTransformerLayer(patcher_config, layer_idx)) - + self.layers.append(BLTTransformerLayer(config.patcher_layer_config, layer_idx)) self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) From aa03d78faf73d90824bc0025675bda9b30cb7fa7 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 16:43:25 +0000 Subject: [PATCH 036/139] decoder, encoder, global configs --- .../models/blt_wip/configuration_blt.py | 259 +++++++++++------- .../models/blt_wip/modeling_blt.py | 190 +++++++------ 2 files changed, 262 insertions(+), 187 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 52b7496b90e5..8d9601433e14 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -37,42 +37,127 @@ class PatchingModeEnum(str, Enum): byte = "byte" -class TransformersLayerConfig: +class BLTLocalEncoderConfig(PretrainedConfig): """ - Configuration class for BLT Transformer layers, providing all necessary parameters - for attention, MLP, and normalization components. + Configuration class for the BLT Local Encoder component. """ + model_type = "blt_local_encoder" + + def __init__( + self, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + head_dim=None, + intermediate_size=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + multiple_of=256, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = head_dim or (hidden_size // num_attention_heads) + self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.hidden_act = hidden_act + self.multiple_of = multiple_of + + super().__init__(**kwargs) + + +class BLTLocalDecoderConfig(PretrainedConfig): + """ + Configuration class for the BLT Local Decoder component. + """ + + model_type = "blt_local_decoder" + + def __init__( + self, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + head_dim=None, + intermediate_size=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + multiple_of=256, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = head_dim or (hidden_size // num_attention_heads) + self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.hidden_act = hidden_act + self.multiple_of = multiple_of + + super().__init__(**kwargs) + + +class BLTGlobalTransformerConfig(PretrainedConfig): + """ + Configuration class for the BLT Global Transformer component. + """ + + model_type = "blt_global_transformer" + def __init__( self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - intermediate_size: int, - norm_eps: float, - dropout: float, - max_position_embeddings: int, - rope_theta: float, - rope_scaling: dict, - hidden_act: str = "silu", + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + head_dim=None, + intermediate_size=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + multiple_of=256, **kwargs, ): self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.intermediate_size = intermediate_size + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = head_dim or (hidden_size // num_attention_heads) + self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.num_hidden_layers = num_hidden_layers self.norm_eps = norm_eps self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling + self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act + self.multiple_of = multiple_of - # Add any additional kwargs as attributes - for key, value in kwargs.items(): - setattr(self, key, value) + super().__init__(**kwargs) class BLTPatcherConfig(PretrainedConfig): @@ -188,8 +273,8 @@ def __init__( self.max_position_embeddings = max_seqlen self.hidden_act = "silu" # BLT uses silu activation - # intermediate_size will be calculated in BLTMLP based on actual hidden_size - self.intermediate_size = None + # Calculate intermediate_size using BLTMLP logic based on actual hidden_size + self.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of) # Set simple rope scaling for patcher (no complex dynamic rope) self.rope_scaling = {"rope_type": "default"} @@ -588,6 +673,49 @@ def __init__( # Special token IDs self.boe_id = boe_id + # Initialize component configurations + self.encoder_config = BLTLocalEncoderConfig( + hidden_size=dim_local_encoder, + num_attention_heads=n_heads_local_encoder, + num_key_value_heads=n_kv_heads, + num_hidden_layers=n_layers_local_encoder, + norm_eps=norm_eps, + dropout=dropout, + max_position_embeddings=max_encoder_seq_length or max_seqlen, + rope_theta=rope_theta, + rope_scaling={"type": "default", "rope_type": "default"}, + hidden_act=hidden_act, + multiple_of=multiple_of, + ) + + self.decoder_config = BLTLocalDecoderConfig( + hidden_size=dim_local_decoder, + num_attention_heads=n_heads_local_decoder, + num_key_value_heads=n_kv_heads, + num_hidden_layers=n_layers_local_decoder, + norm_eps=norm_eps, + dropout=dropout, + max_position_embeddings=max_encoder_seq_length or max_seqlen, + rope_theta=rope_theta, + rope_scaling={"type": "default", "rope_type": "default"}, + hidden_act=hidden_act, + multiple_of=multiple_of, + ) + + self.global_config = BLTGlobalTransformerConfig( + hidden_size=dim_global, + num_attention_heads=n_heads_global, + num_key_value_heads=n_kv_heads_global, + num_hidden_layers=n_layers_global, + norm_eps=norm_eps, + dropout=dropout, + max_position_embeddings=max_seqlen, + rope_theta=rope_theta, + rope_scaling={"type": "default", "rope_type": "default"}, + hidden_act=hidden_act, + multiple_of=multiple_of, + ) + # Initialize patcher configuration if patcher_args is not None: self.patcher_config = BLTPatcherConfig(**patcher_args) @@ -616,63 +744,6 @@ def __init__( # Note: Each component uses its own hidden dimension, not the main dim self.intermediate_size = None # Will be calculated per component - # layer configurations as dictionaries (needed to be JSON serializable!) - self._encoder_layer_config_dict = { - "hidden_size": self.dim_local_encoder, - "num_attention_heads": self.n_heads_local_encoder, - "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_encoder, - "head_dim": self.dim_local_encoder // self.n_heads_local_encoder, - "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_encoder / 3) + self.multiple_of - 1) // self.multiple_of), - "norm_eps": self.norm_eps, - "dropout": self.dropout, - "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen, - "rope_theta": self.rope_theta, - "rope_scaling": self.rope_scaling, - "hidden_act": self.hidden_act, - } - - self._decoder_layer_config_dict = { - "hidden_size": self.dim_local_decoder, - "num_attention_heads": self.n_heads_local_decoder, - "num_key_value_heads": getattr(self, 'n_kv_heads', None) or self.n_heads_local_decoder, - "head_dim": self.dim_local_decoder // self.n_heads_local_decoder, - "intermediate_size": self.multiple_of * ((int(8 * self.dim_local_decoder / 3) + self.multiple_of - 1) // self.multiple_of), - "norm_eps": self.norm_eps, - "dropout": self.dropout, - "max_position_embeddings": self.max_encoder_seq_length or self.max_seqlen, - "rope_theta": self.rope_theta, - "rope_scaling": self.rope_scaling, - "hidden_act": self.hidden_act, - } - - self._global_layer_config_dict = { - "hidden_size": self.dim_global, - "num_attention_heads": self.n_heads_global, - "num_key_value_heads": getattr(self, 'n_kv_heads_global', None) or self.n_heads_global, - "head_dim": self.dim_global // self.n_heads_global, - "intermediate_size": self.multiple_of * ((int(8 * self.dim_global / 3) + self.multiple_of - 1) // self.multiple_of), - "norm_eps": self.norm_eps, - "dropout": self.dropout, - "max_position_embeddings": self.max_seqlen, - "rope_theta": self.rope_theta, - "rope_scaling": self.rope_scaling, - "hidden_act": self.hidden_act, - } - - self._patcher_layer_config_dict = { - "hidden_size": self.patcher_config.dim, - "num_attention_heads": self.patcher_config.n_heads, - "num_key_value_heads": getattr(self.patcher_config, 'n_kv_heads', None) or self.patcher_config.n_heads, - "head_dim": self.patcher_config.dim // self.patcher_config.n_heads, - "intermediate_size": self.patcher_config.multiple_of * ((int(8 * self.patcher_config.dim / 3) + self.patcher_config.multiple_of - 1) // self.patcher_config.multiple_of), - "norm_eps": self.patcher_config.norm_eps, - "dropout": self.patcher_config.dropout, - "max_position_embeddings": self.patcher_config.max_seqlen, - "rope_theta": self.patcher_config.rope_theta, - "rope_scaling": self.patcher_config.rope_scaling, - "hidden_act": self.hidden_act, - } - super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, @@ -680,21 +751,6 @@ def __init__( **kwargs, ) - @property - def encoder_layer_config(self) -> TransformersLayerConfig: - return TransformersLayerConfig(**self._encoder_layer_config_dict) - - @property - def decoder_layer_config(self) -> TransformersLayerConfig: - return TransformersLayerConfig(**self._decoder_layer_config_dict) - - @property - def global_layer_config(self) -> TransformersLayerConfig: - return TransformersLayerConfig(**self._global_layer_config_dict) - - @property - def patcher_layer_config(self) -> TransformersLayerConfig: - return TransformersLayerConfig(**self._patcher_layer_config_dict) @property def encoder_dim_token_emb(self): @@ -758,5 +814,16 @@ def get_init_std_factor(self, depth: int) -> float: else: # DISABLED return 1.0 -__all__ = ["BLTConfig", "BLTPatcherConfig", "TransformersLayerConfig", "InitStdFactor", "PatchingModeEnum"] + + + +__all__ = [ + "BLTConfig", + "BLTPatcherConfig", + "BLTLocalEncoderConfig", + "BLTLocalDecoderConfig", + "BLTGlobalTransformerConfig", + "InitStdFactor", + "PatchingModeEnum" +] diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 1a44faee0648..bda6772572f3 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -30,8 +30,11 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from .configuration_blt import ( BLTConfig, + BLTLocalEncoderConfig, + BLTLocalDecoderConfig, + BLTGlobalTransformerConfig, + BLTPatcherConfig, PatchingModeEnum, - TransformersLayerConfig, ) if is_torch_flex_attn_available(): @@ -152,9 +155,88 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +class BLTTransformerLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + # Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT class BLTSelfAttention(nn.Module): - def __init__(self, config: TransformersLayerConfig, layer_idx: int): + def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.num_heads = config.num_attention_heads @@ -235,86 +317,6 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer -class BLTTransformerLayer(nn.Module): - def __init__(self, config: TransformersLayerConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor`, *optional*): - Position indices of tokens in the sequence for RoPE computation. - past_key_value (`Cache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 @@ -482,7 +484,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> class BLTRotaryEmbedding(nn.Module): - def __init__(self, config: TransformersLayerConfig, device=None): + def __init__(self, config, device=None): super().__init__() self.rope_type = config.rope_scaling["rope_type"] self.max_seq_len_cached = config.max_position_embeddings @@ -530,9 +532,10 @@ def __init__(self, config: BLTConfig): self.norm_eps = config.norm_eps self.sliding_window = config.sliding_window - self.layers = nn.ModuleList([BLTTransformerLayer(config.encoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + encoder_config = config.encoder_config + self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) - self.rotary_emb = BLTRotaryEmbedding(config=config.encoder_layer_config) + self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) self.token_embedding_projection = ( nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) @@ -671,9 +674,10 @@ def __init__(self, config: BLTConfig): self.cross_attn_k = config.cross_attn_k self.sliding_window = config.sliding_window - self.layers = nn.ModuleList([BLTTransformerLayer(config.decoder_layer_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + decoder_config = config.decoder_config + self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) - self.rotary_emb = BLTRotaryEmbedding(config=config.decoder_layer_config) + self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) self.token_embedding_projection = ( nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) @@ -883,11 +887,12 @@ def __init__(self, config): self.n_layers_global = config.n_layers_global self.dropout = config.dropout + global_config = config.global_config self.layers = nn.ModuleList() for layer_idx in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(config.global_layer_config, layer_idx)) + self.layers.append(BLTTransformerLayer(global_config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=config.global_layer_config) + self.rotary_emb = BLTRotaryEmbedding(config=global_config) self.token_embedding_projection = None if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: @@ -1263,8 +1268,10 @@ def __init__(self, config): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() + # Create transformer layers using the patcher config for layer_idx in range(self.config.n_layers): - self.layers.append(BLTTransformerLayer(config.patcher_layer_config, layer_idx)) + self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) @@ -1471,6 +1478,7 @@ def init_hash_embeddings( "BLTModel", "BLTPatcher", "BLTLocalEncoder", - "BLTLocalDecoder", + "BLTLocalDecoder", "BLTGlobalTransformer", + "BLTTransformerLayer", ] \ No newline at end of file From 494b48812677c4078a41f391a89f44a522563305 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 23 Jun 2025 16:59:34 +0000 Subject: [PATCH 037/139] wip working on modular file --- .../models/blt_wip/modeling_blt_modular.py | 1235 +++++++++++++++++ 1 file changed, 1235 insertions(+) create mode 100644 src/transformers/models/blt_wip/modeling_blt_modular.py diff --git a/src/transformers/models/blt_wip/modeling_blt_modular.py b/src/transformers/models/blt_wip/modeling_blt_modular.py new file mode 100644 index 000000000000..217cd809af80 --- /dev/null +++ b/src/transformers/models/blt_wip/modeling_blt_modular.py @@ -0,0 +1,1235 @@ +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model.""" + +from ...utils import is_torch_flex_attn_available, logging +from typing import Callable, List, Optional, Tuple, Union + +from ...cache_utils import Cache +from ...activations import ACT2FN + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from .configuration_blt import ( + BLTConfig, + PatchingModeEnum, +) + +from ..mllama.modeling_mllama import MllamaTextRMSNorm, MllamaTextMLP, MllamaTextCrossAttention, MllamaRotaryEmbedding, MllamaTextSelfAttention, MllamaSelfAttentionDecoderLayer, eager_attention_forward, repeat_kv + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP +class BLTMLP(MllamaTextMLP): + pass + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class BLTRMSNorm(MllamaTextRMSNorm): + pass + + +class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): + pass + + +# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT +class BLTSelfAttention(MllamaTextSelfAttention): + pass + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """ + Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. + + expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + token_ids = torch.cat([prefix, token_ids], dim=1) + windows = token_ids.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( + batch_size, num_patches, seq_len + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( + batch_size, seq_len, num_patches + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + split_all = [] + max_len = 0 + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + # Split long patches into max_patch_length chunks + full, rem = divmod(length.item(), max_patch_length) + splits.extend([max_patch_length] * full + ([rem] if rem else [])) + split_all.append(splits) + max_len = max(max_len, len(splits)) + + # Pad sequences to the maximum length + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + for i, splits in enumerate(split_all): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim trailing columns that are all zeros + last_non_zero = (padded != 0).flip(1).int().argmax(1).min() + if last_non_zero < padded.shape[1]: + padded = padded[:, :padded.shape[1] - last_non_zero] + + return padded + + +class BLTRotaryEmbedding(MllamaRotaryEmbedding): + pass + + + +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTConfig): + super().__init__() + + # Extract config values to instance attributes + self.dropout = config.dropout + self.dim_local_encoder = config.dim_local_encoder + self.n_layers_local_encoder = config.n_layers_local_encoder + self.n_heads_local_encoder = config.n_heads_local_encoder + self.vocab_size = config.vocab_size + self.pm_size = config.pm_size + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_k = config.cross_attn_k + self.norm_eps = config.norm_eps + self.sliding_window = config.sliding_window + + encoder_config = config.encoder_config + self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + + self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) + + self.token_embedding_projection = ( + nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) + if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) + + self.cross_attn_layers = None + if self.cross_attn_encoder and self.cross_attn_nheads is not None: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) + ) + + def forward( + self, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + batch_size, sequence_length = input_ids.shape + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + + batch_size, _, _ = input_embeds.shape + + hidden_states = input_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # Initialize patch_embeds if not provided when cross attention is enabled + if patch_embeds is None: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) + + layer_idx = idx if self.cross_attn_all_layers_encoder else 0 + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + patch_embeds = patch_embeds + cross_attention_output + + encoder_cross_states = patch_embeds if self.cross_attn_encoder else None + return (hidden_states, encoder_cross_states), cache + + def _create_patch_projection(self, config): + dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder + + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.encoder_dim_token_emb * config.cross_attn_k + + return nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + batch_size, seq_len, embedding_dim = hidden_states.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BLTLocalDecoder(nn.Module): + def __init__(self, config: BLTConfig): + super().__init__() + + # Extract config values to instance attributes + self.dim_local_decoder = config.dim_local_decoder + self.n_heads_local_decoder = config.n_heads_local_decoder + self.n_layers_local_decoder = config.n_layers_local_decoder + self.vocab_size = config.vocab_size + self.norm_eps = config.norm_eps + self.dropout = config.dropout + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_nheads = config.cross_attn_nheads + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_k = config.cross_attn_k + self.sliding_window = config.sliding_window + + decoder_config = config.decoder_config + self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + + self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) + + self.token_embedding_projection = ( + nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) + if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps) + + self.cross_attn_layers = None + if self.cross_attn_decoder and self.cross_attn_nheads is not None: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) + ) + + self.lm_head = nn.Linear( + self.dim_local_decoder, + self.vocab_size, + bias=False, + ) + + def _create_patch_projection(self, config): + dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder + + # Check cross attention conditions + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + if not (dimension_mismatch or cross_attn_conditions): + return None + + output_dim = config.decoder_dim_token_emb * config.cross_attn_k + + return nn.Linear( + in_features=config.dim_global, + out_features=output_dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, sequence_length = tokens.shape + batch_size, seq_length, _ = embeds.shape + + assert embeds is not None, "Embeddings must be provided" + + hidden_states = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output, _, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + hidden_states = hidden_states + cross_attention_output + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + return logits, cache + + +class BLTCrossAttention(nn.Module): + """Cross-attention module for BLT, following transformers style""" + + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.dim_local_encoder + self.num_heads = config.cross_attn_nheads + self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None #self.head_dim ** -0.5 + self.dropout = config.dropout + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" + ) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, #if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # Apply full row masking if provided (following mllama pattern) + if full_text_row_masked_out_mask is not None: + attn_output = full_text_row_masked_out_mask[:, 0] * attn_output + + attn_output = attn_output + hidden_states + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BLTGlobalTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + self.dim_global = config.dim_global + self.n_heads_global = config.n_heads_global + self.n_layers_global = config.n_layers_global + self.dropout = config.dropout + + global_config = config.global_config + self.layers = nn.ModuleList() + for layer_idx in range(self.n_layers_global): + self.layers.append(BLTTransformerLayer(global_config, layer_idx)) + + self.rotary_emb = BLTRotaryEmbedding(config=global_config) + + self.token_embedding_projection = None + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: + self.token_embedding_projection = nn.Linear( + config.global_dim_patch_emb, + self.dim_global, + bias=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, seq_length, _ = input_embeds.shape + + hidden_states = input_embeds + + if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: + hidden_states = self.token_embedding_projection(hidden_states) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + return hidden_states, cache + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. + + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with embed_tokens method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.config.dim_local_encoder ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, BLTLocalEncoder): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) + + elif isinstance(module, BLTGlobalTransformer): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.dim ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + + # Extract frequently used config values + self.patch_in_forward = config.patch_in_forward + self.patching_mode = config.patching_mode + self.patch_size = config.patch_size + self.patching_threshold = config.patching_threshold + self.max_patch_length = config.max_patch_length + self.patching_batch_size = config.patching_batch_size + self.patching_device = config.patching_device + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_k = config.cross_attn_k + self.cross_attn_window_encoder = config.cross_attn_window_encoder + self.cross_attn_window_decoder = config.cross_attn_window_decoder + self.boe_id = config.boe_id + self.eos_token_id = config.eos_token_id + + self.local_encoder = BLTLocalEncoder(config) + self.global_transformer = BLTGlobalTransformer(config) + self.local_decoder = BLTLocalDecoder(config) + + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=config.dim_local_encoder, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + if self.patch_in_forward: + self.patcher = BLTPatcher(config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ): + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model + + batch_size, sequence_length = tokens.shape # Batch size and sequence length + + local_encoder_tokens, local_decoder_tokens = tokens, tokens + + # Patching + if patch_lengths is None: + # assert ( + # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." + + # PATCHER MODEL DEFINED + if self.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.patch_size, + include_next_token=True, + threshold=self.patching_threshold, + max_patch_length=self.max_patch_length, + patching_batch_size=self.patching_batch_size, + device=self.patching_device, + ) + else: + # self.patching_mode == PatchingModeEnum.byte + batch_size_tokens, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) + + #assert torch.min(patch_lengths) >= 0 + # Generate patch IDs from patch_lengths + patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) + # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( + # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + # ) + + # Cross-attention encoder + cross_attn_mask_enc = None + full_text_row_masked_out_mask_enc = None + if self.cross_attn_encoder: + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids=patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=True, + cross_attn_k=self.cross_attn_k, + dtype=torch.float32, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + ) + + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings + + # Local encoder + (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( + input_ids=local_encoder_tokens, + input_embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + if encoder_cross_states is not None: + # Cross attention is enabled - use cross states + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + else: + # Cross attention is disabled - use reduced embeddings from encoder hidden states + global_hidden_states = self.local_encoder.patch_reduce( + encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids + ) + + # Global transformer + global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.eos_token_id + + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, + input_ids=global_tokens, + ) + + # Unpatching + + decoder_embeds = encoder_hidden_states + + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) + # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" + # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" + # ) + + # Cross-attention decoder + cross_attn_mask_dec = None + full_text_row_masked_out_mask_dec = None + if not self.cross_attn_decoder: + patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) + # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] + else: + patch_hidden_states = global_hidden_states + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + patch_ids=decoder_patch_ids, + num_patches=patch_lengths.shape[1], + sequence_length=sequence_length, + patches_as_queries=False, + cross_attn_k=self.cross_attn_k, + dtype=torch.float32, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=decoder_embeds, + patch_embeds=patch_hidden_states, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + ) + return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total + ], + dim=-1, + ) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config.patcher_config) + + self.rotary_emb = BLTRotaryEmbedding(config=self.config) + + self.layers = nn.ModuleList() + # Create transformer layers using the patcher config + for layer_idx in range(self.config.n_layers): + self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + + + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) + + self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps) + + self.lm_head = nn.Linear( + self.config.dim, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + ): + + # Handle chunked processing for entropy calculation + entropies = [] + predictions = [] + max_length = self.config.max_seqlen + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + batch_size, sequence_length = split.shape + input_embeds = self.embed_tokens(split) + + hidden_states = input_embeds + + batch_size, _, _ = input_embeds.shape + + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + predictions.append(logits) + prediction_entropies = self.entropy(logits) + entropies.append(prediction_entropies) + + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + batch_size, sequence_length = token_values.shape + seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold + ) + patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) + else: + # Default to byte-level patching + patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return concat_entropies, patch_lengths, concat_predictions + + + @staticmethod + def entropy(scores): + """ + scores: [batch_size, seq_len, vocab] + returns [batch_size, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + batch_size, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (batch_size, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) + extra_patch_ids = torch.full( + (batch_size, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + batch_size, sequence_length = entropies.shape[:2] + + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. + entropies = entropies[:, 1:] + if threshold is None: + num_patches = sequence_length // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) + return patch_start_ids + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) + + + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", + "BLTTransformerLayer", +] \ No newline at end of file From 477406e9964840fee5e08e3cb535b89182e81f49 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 25 Jun 2025 10:38:04 +0000 Subject: [PATCH 038/139] cleaning up patch and configs --- .../models/blt_wip/configuration_blt.py | 85 +-- .../models/blt_wip/modeling_blt.py | 524 ++++++------------ 2 files changed, 224 insertions(+), 385 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 8d9601433e14..476a9d22ffc1 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -45,7 +45,12 @@ class BLTLocalEncoderConfig(PretrainedConfig): model_type = "blt_local_encoder" def __init__( - self, + self, + vocab_size=256, + cross_attn_all_layers=True, + cross_attn_k=2, + dim_global=2048, + pm_size=0, hidden_size=512, num_attention_heads=8, num_key_value_heads=None, @@ -59,8 +64,14 @@ def __init__( rope_scaling=None, hidden_act="silu", multiple_of=256, + _attn_implementation="sdpa", **kwargs, ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.dim_global=dim_global + self.pm_size=pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -74,10 +85,13 @@ def __init__( self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act self.multiple_of = multiple_of + self._attn_implementation = _attn_implementation + self.decoder_dim_token_emb = 1024 + self.encoder_dim_token_emb=1024 + self.encoder_dim_patch_emb=self.hidden_size super().__init__(**kwargs) - - + class BLTLocalDecoderConfig(PretrainedConfig): """ Configuration class for the BLT Local Decoder component. @@ -87,6 +101,10 @@ class BLTLocalDecoderConfig(PretrainedConfig): def __init__( self, + vocab_size=256, + cross_attn_all_layers=True, + cross_attn_k=2, + dim_global=2048, hidden_size=512, num_attention_heads=8, num_key_value_heads=None, @@ -100,8 +118,13 @@ def __init__( rope_scaling=None, hidden_act="silu", multiple_of=256, + _attn_implementation="sdpa", **kwargs, ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.dim_global=dim_global self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -115,7 +138,10 @@ def __init__( self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act self.multiple_of = multiple_of - + self._attn_implementation = _attn_implementation + self.decoder_dim_token_emb=1024 + self.encoder_dim_token_emb=1024 + super().__init__(**kwargs) @@ -141,6 +167,7 @@ def __init__( rope_scaling=None, hidden_act="silu", multiple_of=256, + _attn_implementation="sdpa", **kwargs, ): self.hidden_size = hidden_size @@ -156,6 +183,7 @@ def __init__( self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act self.multiple_of = multiple_of + self._attn_implementation = _attn_implementation super().__init__(**kwargs) @@ -511,6 +539,7 @@ def __init__( rope_use_fp32_in_outer_product=False, # Attention configuration attn_impl="sdpa", + _attn_implementation="sdpa", attn_bias_type="causal", local_attention_window_len=None, use_rope=True, @@ -537,8 +566,8 @@ def __init__( cross_attn_decoder=False, cross_attn_window_encoder=None, cross_attn_window_decoder=None, - cross_attn_k=1, - cross_attn_nheads=None, + cross_attn_k=2, + cross_attn_nheads=16, cross_attn_all_layers_decoder=False, cross_attn_all_layers_encoder=False, cross_attn_use_flex_attention=True, @@ -610,6 +639,7 @@ def __init__( # Attention configuration self.attn_impl = attn_impl + self._attn_implementation = _attn_implementation self.attn_bias_type = attn_bias_type self.local_attention_window_len = local_attention_window_len self.use_rope = use_rope @@ -675,6 +705,11 @@ def __init__( # Initialize component configurations self.encoder_config = BLTLocalEncoderConfig( + vocab_size=vocab_size, + cross_attn_all_layers=cross_attn_all_layers_encoder, + cross_attn_k=cross_attn_k, + dim_global=dim_global, + pm_size=pm_size, hidden_size=dim_local_encoder, num_attention_heads=n_heads_local_encoder, num_key_value_heads=n_kv_heads, @@ -689,6 +724,10 @@ def __init__( ) self.decoder_config = BLTLocalDecoderConfig( + vocab_size=vocab_size, + cross_attn_all_layers=cross_attn_all_layers_decoder, + cross_attn_k=cross_attn_k, + dim_global=dim_global, hidden_size=dim_local_decoder, num_attention_heads=n_heads_local_decoder, num_key_value_heads=n_kv_heads, @@ -702,6 +741,7 @@ def __init__( multiple_of=multiple_of, ) + self.global_config = BLTGlobalTransformerConfig( hidden_size=dim_global, num_attention_heads=n_heads_global, @@ -751,30 +791,9 @@ def __init__( **kwargs, ) - - @property - def encoder_dim_token_emb(self): - if self.dim_token is not None: - return self.dim_token - elif self.use_local_encoder_transformer: - return self.dim_local_encoder - else: - # Use default patch_size of 8 if not set - patch_size = self.patch_size if self.patch_size is not None else 8 - return self.dim_global // patch_size - - @property - def encoder_dim_patch_emb(self): - if self.cross_attn_encoder: - if self.cross_attn_init_by_pooling: - return self.dim_local_encoder - else: - return self.dim_global - return None - @property def global_dim_patch_emb(self): - dim_token_emb = self.encoder_dim_token_emb + dim_token_emb = self.dim_local_encoder if self.cross_attn_encoder: cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 return dim_token_emb * cross_attn_k @@ -788,16 +807,6 @@ def global_dim_patch_emb(self): else: return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]]) - @property - def decoder_dim_token_emb(self): - if self.share_encoder_decoder_emb: - return self.encoder_dim_token_emb - elif self.dim_token is not None: - return self.dim_token - else: - return self.dim_local_decoder - - def get_init_std_factor(self, depth: int) -> float: """ diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index bda6772572f3..a08626d495d7 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -21,6 +21,7 @@ from ...activations import ACT2FN import torch +import torch.distributions import torch.nn import torch.nn as nn from torch.nn import functional as F @@ -514,47 +515,42 @@ def forward(self, x, position_ids): class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTConfig): + def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - - # Extract config values to instance attributes + + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.dim_local_encoder = config.dim_local_encoder - self.n_layers_local_encoder = config.n_layers_local_encoder - self.n_heads_local_encoder = config.n_heads_local_encoder - self.vocab_size = config.vocab_size - self.pm_size = config.pm_size - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_nheads = config.cross_attn_nheads - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k - self.norm_eps = config.norm_eps - self.sliding_window = config.sliding_window - encoder_config = config.encoder_config - self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) - self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) self.token_embedding_projection = ( - nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) - if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder + nn.Linear(config.encoder_dim_token_emb, self.hidden_size, bias=False) + if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.hidden_size else None ) - self.patch_embedding_projection = self._create_patch_projection(config) + self.patch_embedding_projection = nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=config.encoder_dim_token_emb * config.cross_attn_k, + bias=False, + ) + - self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) + self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) - self.cross_attn_layers = None - if self.cross_attn_encoder and self.cross_attn_nheads is not None: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) - ) + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) def forward( self, @@ -569,34 +565,28 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - batch_size, sequence_length = input_ids.shape if input_embeds is None: input_embeds = self.embed_tokens(input_ids) batch_size, _, _ = input_embeds.shape - hidden_states = input_embeds - - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) for idx, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): - # Initialize patch_embeds if not provided when cross attention is enabled - if patch_embeds is None: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) - if self.patch_embedding_projection is not None: - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) + if idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) - layer_idx = idx if self.cross_attn_all_layers_encoder else 0 + layer_idx = idx if self.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -608,27 +598,9 @@ def forward( ) patch_embeds = patch_embeds + cross_attention_output - encoder_cross_states = patch_embeds if self.cross_attn_encoder else None - return (hidden_states, encoder_cross_states), cache + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states - def _create_patch_projection(self, config): - dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder - - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - if not (dimension_mismatch or cross_attn_conditions): - return None - - output_dim = config.encoder_dim_token_emb * config.cross_attn_k - - return nn.Linear( - in_features=config.encoder_dim_patch_emb, - out_features=output_dim, - bias=False, - ) - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch @@ -640,7 +612,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, seq_len, embedding_dim = hidden_states.shape + batch_size, _, embedding_dim = hidden_states.shape patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -658,70 +630,43 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BLTLocalDecoder(nn.Module): - def __init__(self, config: BLTConfig): + def __init__(self, config: BLTLocalDecoderConfig): super().__init__() # Extract config values to instance attributes - self.dim_local_decoder = config.dim_local_decoder - self.n_heads_local_decoder = config.n_heads_local_decoder - self.n_layers_local_decoder = config.n_layers_local_decoder - self.vocab_size = config.vocab_size - self.norm_eps = config.norm_eps + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_nheads = config.cross_attn_nheads - self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k - self.sliding_window = config.sliding_window - decoder_config = config.decoder_config - self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) - self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = ( - nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) - if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder - else None + self.patch_embedding_projection = nn.Linear( + in_features=config.dim_global, + out_features=config.decoder_dim_token_emb * config.cross_attn_k, + bias=False, ) - self.patch_embedding_projection = self._create_patch_projection(config) - - self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps) + self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps) - self.cross_attn_layers = None - if self.cross_attn_decoder and self.cross_attn_nheads is not None: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) - ) + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) self.lm_head = nn.Linear( - self.dim_local_decoder, + self.hidden_size, self.vocab_size, bias=False, ) - def _create_patch_projection(self, config): - dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder - - # Check cross attention conditions - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - if not (dimension_mismatch or cross_attn_conditions): - return None - - output_dim = config.decoder_dim_token_emb * config.cross_attn_k - - return nn.Linear( - in_features=config.dim_global, - out_features=output_dim, - bias=False, - ) def forward( self, @@ -733,18 +678,12 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, sequence_length = tokens.shape - batch_size, seq_length, _ = embeds.shape - - assert embeds is not None, "Embeddings must be provided" + batch_size, _, _ = embeds.shape hidden_states = embeds - if self.patch_embedding_projection is not None: - assert patch_embeds is not None, "Patch embeddings must be passed." - patch_embeds = self.patch_embedding_projection(patch_embeds) - if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds @@ -754,7 +693,7 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + if i == 0 or self.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, @@ -783,8 +722,8 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.layer_idx = layer_idx # Use provided hidden_size or fallback to encoder dimension self.hidden_size = hidden_size or config.dim_local_encoder - self.num_heads = config.cross_attn_nheads - self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = None #self.head_dim ** -0.5 @@ -839,6 +778,8 @@ def forward( attention_interface: Callable = eager_attention_forward + self.config._attn_implementation = "sdpa" + attn = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -905,7 +846,6 @@ def __init__(self, config): def forward( self, input_ids: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, input_embeds: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, @@ -1042,7 +982,7 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - # Extract frequently used config values + # Core configuration self.patch_in_forward = config.patch_in_forward self.patching_mode = config.patching_mode self.patch_size = config.patch_size @@ -1050,24 +990,27 @@ def __init__(self, config: BLTConfig): self.max_patch_length = config.max_patch_length self.patching_batch_size = config.patching_batch_size self.patching_device = config.patching_device - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_decoder = config.cross_attn_decoder + + # Cross attention configuration (always enabled) self.cross_attn_k = config.cross_attn_k - self.cross_attn_window_encoder = config.cross_attn_window_encoder - self.cross_attn_window_decoder = config.cross_attn_window_decoder + + # Token IDs self.boe_id = config.boe_id self.eos_token_id = config.eos_token_id - self.local_encoder = BLTLocalEncoder(config) + # Model components + self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config) - self.local_decoder = BLTLocalDecoder(config) + self.local_decoder = BLTLocalDecoder(config.decoder_config) + # Hash embeddings self.encoder_hash_tok_embedding = init_hash_embeddings( config, local_encoder_dim=config.dim_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) + # Patcher initialization if self.patch_in_forward: self.patcher = BLTPatcher(config) self.patcher.eval() @@ -1076,28 +1019,14 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - def forward( - self, - tokens: torch.Tensor, - patch_lengths: Optional[torch.Tensor] = None, - ): - # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings - # are no longer used in the final BLT model - - batch_size, sequence_length = tokens.shape # Batch size and sequence length - - local_encoder_tokens, local_decoder_tokens = tokens, tokens + def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + batch_size, sequence_length = tokens.shape - # Patching + # Handle patching if patch_lengths is None: - # assert ( - # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward - # ), "Patch in forward not enabled and no patch_lengths passed." - - # PATCHER MODEL DEFINED if self.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( - local_encoder_tokens, + tokens, patch_size=self.patch_size, include_next_token=True, threshold=self.patching_threshold, @@ -1106,52 +1035,30 @@ def forward( device=self.patching_device, ) else: - # self.patching_mode == PatchingModeEnum.byte - batch_size_tokens, seq_len = local_encoder_tokens.shape - seq_len_next_tok = seq_len + 1 # include_next_token=True - patch_lengths = torch.ones( - (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + # Default to byte-level patching + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + self.max_patch_length ) - patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) - - #assert torch.min(patch_lengths) >= 0 - # Generate patch IDs from patch_lengths - patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) - # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( - # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" - # ) - - # Cross-attention encoder - cross_attn_mask_enc = None - full_text_row_masked_out_mask_enc = None - if self.cross_attn_encoder: - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids=patch_ids, - num_patches=patch_lengths.shape[1], - sequence_length=sequence_length, - patches_as_queries=True, - cross_attn_k=self.cross_attn_k, - dtype=torch.float32, - ) - - # Hashing and embedding - local_encoder_embeds = compute_hash_embeddings( - local_encoder_tokens=local_encoder_tokens, - local_encoder=self.local_encoder, - encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, - encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, - encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, - encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + # Generate patch IDs and prepare cross-attention masks + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.cross_attn_k, torch.float32 ) - # NOTE: Frequency-based n-gram embeddings removed as per paper - # The final BLT model uses only hash-based n-gram embeddings + # Compute embeddings with hashing + encoder_embeds = compute_hash_embeddings( + tokens, self.local_encoder, self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) - # Local encoder - (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( - input_ids=local_encoder_tokens, - input_embeds=local_encoder_embeds, + # Local encoder forward pass + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=tokens, + input_embeds=encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, @@ -1159,106 +1066,44 @@ def forward( patch_ids=patch_ids, ) - # Downsampling - if encoder_cross_states is not None: - # Cross attention is enabled - use cross states - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - else: - # Cross attention is disabled - use reduced embeddings from encoder hidden states - global_hidden_states = self.local_encoder.patch_reduce( - encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids - ) - - # Global transformer - global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) - eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.eos_token_id + # Global transformer forward pass + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_tokens = tokens.new_full((batch_size, global_hidden_states.shape[1]), self.boe_id) + eos_positions = torch.where(tokens == self.eos_token_id) + global_tokens[eos_positions[0], patch_ids[eos_positions]] = self.eos_token_id global_hidden_states, _ = self.global_transformer( - input_embeds=global_hidden_states, input_ids=global_tokens, + input_embeds=global_hidden_states, ) - # Unpatching - - decoder_embeds = encoder_hidden_states - - # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) - # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" - # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( - # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" - # ) - - # Cross-attention decoder - cross_attn_mask_dec = None - full_text_row_masked_out_mask_dec = None - if not self.cross_attn_decoder: - patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) - # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] - else: - patch_hidden_states = global_hidden_states - cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - patch_ids=decoder_patch_ids, - num_patches=patch_lengths.shape[1], - sequence_length=sequence_length, - patches_as_queries=False, - cross_attn_k=self.cross_attn_k, - dtype=torch.float32, - ) + # Local decoder forward pass + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.cross_attn_k, torch.float32 + ) - # Local decoder output, _ = self.local_decoder( - embeds=decoder_embeds, - patch_embeds=patch_hidden_states, - tokens=local_decoder_tokens, + tokens=tokens, + embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) + return output def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """ - Convert patch lengths to patch IDs for each token position. - For each token position in the sequence, determines which patch it belongs to. - - Args: - patch_lengths: [batch_size, num_patches] - length of each patch - seq_len: total sequence length - - Returns: - patch_ids: [batch_size, seq_len] - patch index for each token position - - Example: - patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 - seq_len = 10 - Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] - # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 - """ - batch_size, num_patches = patch_lengths.shape - - # Create patch start positions: [0, 3, 5, 9] for the example above - patch_starts = torch.cat( - [ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total - ], - dim=-1, - ) - - # For each token position, find which patch it belongs to - # by finding the rightmost patch start that's <= the position - token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] - - # Broadcasting: patch_starts[batch, patch] <= token_positions[position] - # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t - position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) - - # Count how many patch starts are <= each position, then subtract 1 to get patch index - patch_ids = position_ge_patch_start.sum(dim=-1) - 1 - - return patch_ids + """Convert patch lengths to patch IDs for each token position.""" + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] + ], dim=-1) + + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 class BLTPatcher(BLTPreTrainedModel): @@ -1328,7 +1173,7 @@ def forward( logits = self.lm_head(self.norm(hidden_states)) logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] predictions.append(logits) - prediction_entropies = self.entropy(logits) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) @@ -1340,13 +1185,14 @@ def forward( # Find patch start IDs based on entropy if patch_size is not None: - patch_start_ids = self.find_entropy_patch_start_ids( + patch_lengths = self.patch_lengths_from_entropies( concat_entropies, - patch_size, - include_next_token=include_next_token, - threshold=threshold + seq_len_next_tok, + patch_size=patch_size, + threshold=threshold, + include_next_token=include_next_token ) - patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) + else: # Default to byte-level patching patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) @@ -1356,83 +1202,61 @@ def forward( @staticmethod - def entropy(scores): + def patch_start_ids_from_patch_start_mask(patch_start_mask: torch.BoolTensor) -> torch.LongTensor: """ - scores: [batch_size, seq_len, vocab] - returns [batch_size, seq_len] + Convert a binary patch start mask into a tensor of patch start indices. - Computes the entropy for each token in the batch. - Note: uses natural log. + Each row in the output contains the indices where patches start, + padded with the sequence length if fewer patches were found. """ - log_probs = F.log_softmax(scores, dim=-1) - probs = torch.exp(log_probs) - p_log_p = log_probs * probs - entropy = -p_log_p.sum(dim=-1) - return entropy + batch_size, seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max().item() + + # Fill with seq_len as padding + patch_start_ids = torch.full( + (batch_size, max_patches), + fill_value=seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) - @staticmethod - def patch_start_ids_from_patch_start_mask(patch_start_mask): - batch_size, trunc_seq_len = patch_start_mask.shape - max_patches = patch_start_mask.sum(dim=1).max() - if max_patches == 0: - patch_start_ids = torch.full( - (batch_size, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - else: - patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) - extra_patch_ids = torch.full( - (batch_size, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) - patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] - return patch_start_ids + for i in range(batch_size): + ids = torch.nonzero(patch_start_mask[i], as_tuple=False).flatten() + patch_start_ids[i, :ids.numel()] = ids - @staticmethod - def patch_lengths_from_start_ids(patch_start_ids, seq_len): - """ - Calculate patch lengths from start ids. - start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then - the rest are filled to the seq len. - seq_len: ex: 7 length of the sequence + return patch_start_ids - returns the patch lengths: - [1, 6] for the above example. - """ - last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) - patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) - patch_lengths = patch_end_ids - patch_start_ids + 1 - assert torch.all(patch_lengths >= 0), f"{patch_lengths}" - assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" - return patch_lengths @staticmethod - def find_entropy_patch_start_ids( - entropies, - patch_size=None, - threshold=None, - include_next_token=True, - ): + def patch_lengths_from_entropies( + entropies: torch.Tensor, + seq_len: int, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + include_next_token: bool = True, + ) -> torch.Tensor: """ - Use entropies to find the start ids of each patch. - Use patch_size or threshold to figure out the total number of patches to allocate. + Compute patch lengths directly from entropies. + + Args: + entropies (Tensor): [batch_size, sequence_length] + seq_len (int): sequence length including next token if used. + patch_size (int, optional): Number of patches to extract if threshold is not given. + threshold (float, optional): Entropy threshold for dynamic patching. + include_next_token (bool): Whether to account for next token in patch span. - When threshold is not None the number of patches is not constant between - different sequences, but patches can be identified incrementally rather than - decided globally using the entire sequence. + Returns: + patch_lengths (Tensor): [batch_size, num_patches] of patch lengths """ batch_size, sequence_length = entropies.shape[:2] + # Always keep first patch starting at 0 and 1 first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. - entropies = entropies[:, 1:] + entropies = entropies[:, 1:] # Skip first token for entropy-based selection + offset = first_ids.shape[1] + if threshold is None: + assert patch_size is not None, "patch_size must be specified when threshold is None" num_patches = sequence_length // patch_size patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices patch_start_ids = patch_start_ids.sort(dim=1).values @@ -1440,11 +1264,17 @@ def find_entropy_patch_start_ids( patch_start_mask = entropies > threshold if not include_next_token: patch_start_mask = patch_start_mask[:, :-1] - # patch_start_mask[1:] |= tokens[:-1] < OFFSET patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) - return patch_start_ids + # Final start ids (prepend 0 and 1) + patch_start_ids = torch.cat((first_ids, patch_start_ids + offset), dim=1) + + # Compute patch lengths + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + return patch_lengths + def init_hash_embeddings( config, From 13a79a52204bbe422123364f86a14fe1904527b3 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 25 Jun 2025 11:30:40 +0000 Subject: [PATCH 039/139] clean up patcher helpers --- .../models/blt_wip/configuration_blt.py | 17 +- .../models/blt_wip/modeling_blt.py | 172 ++++++------------ 2 files changed, 64 insertions(+), 125 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 476a9d22ffc1..fed44222add6 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -168,6 +168,7 @@ def __init__( hidden_act="silu", multiple_of=256, _attn_implementation="sdpa", + global_dim_patch_emb=None, **kwargs, ): self.hidden_size = hidden_size @@ -184,6 +185,7 @@ def __init__( self.hidden_act = hidden_act self.multiple_of = multiple_of self._attn_implementation = _attn_implementation + self.global_dim_patch_emb = global_dim_patch_emb super().__init__(**kwargs) @@ -754,6 +756,7 @@ def __init__( rope_scaling={"type": "default", "rope_type": "default"}, hidden_act=hidden_act, multiple_of=multiple_of, + global_dim_patch_emb=self.global_dim_patch_emb, ) # Initialize patcher configuration @@ -793,19 +796,7 @@ def __init__( @property def global_dim_patch_emb(self): - dim_token_emb = self.dim_local_encoder - if self.cross_attn_encoder: - cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 - return dim_token_emb * cross_attn_k - elif ( - self.downsampling_by_pooling is None - or not self.downsampling_by_pooling - or len(self.downsampling_by_pooling) == 0 - ): - patch_size = self.patch_size if self.patch_size is not None else 8 - return dim_token_emb * patch_size - else: - return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]]) + return self.dim_local_encoder * self.cross_attn_k def get_init_std_factor(self, depth: int) -> float: diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index a08626d495d7..72ac16522847 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -379,6 +379,8 @@ def _prepare_patch_cross_attention_mask( cross_attn_k: int = 1, dtype: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: + #TODO: refactor to be more readable + """ Prepare cross-attention mask for patch-based attention, following mllama's robust approach. @@ -454,6 +456,7 @@ def _prepare_patch_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: + #TODO: refactor to be more readable if max_patch_length is None: return patch_lengths @@ -513,7 +516,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() @@ -522,7 +524,6 @@ def __init__(self, config: BLTLocalEncoderConfig): self.vocab_size=config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k @@ -530,19 +531,12 @@ def __init__(self, config: BLTLocalEncoderConfig): self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = ( - nn.Linear(config.encoder_dim_token_emb, self.hidden_size, bias=False) - if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.hidden_size - else None - ) - self.patch_embedding_projection = nn.Linear( in_features=config.encoder_dim_patch_emb, out_features=config.encoder_dim_token_emb * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) self.cross_attn_layers = torch.nn.ModuleList() @@ -820,28 +814,19 @@ def forward( class BLTGlobalTransformer(nn.Module): - def __init__(self, config): + def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() - self.dim_global = config.dim_global - self.n_heads_global = config.n_heads_global - self.n_layers_global = config.n_layers_global + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - global_config = config.global_config self.layers = nn.ModuleList() - for layer_idx in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(global_config, layer_idx)) + for layer_idx in range(self.num_hidden_layers): + self.layers.append(BLTTransformerLayer(config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=global_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: - self.token_embedding_projection = nn.Linear( - config.global_dim_patch_emb, - self.dim_global, - bias=False, - ) def forward( self, @@ -850,13 +835,10 @@ def forward( mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, seq_length, _ = input_embeds.shape + batch_size, _, _ = input_embeds.shape hidden_states = input_embeds - if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: - hidden_states = self.token_embedding_projection(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) @@ -926,7 +908,6 @@ class BLTPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Linear): std = getattr(module, '_custom_std', module.in_features ** (-0.5)) - nn.init.trunc_normal_( module.weight, mean=0.0, @@ -939,7 +920,6 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( module.weight, mean=0.0, @@ -955,23 +935,13 @@ def _init_weights(self, module): emb._custom_std = emb_std elif isinstance(module, BLTLocalEncoder): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) - if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) elif isinstance(module, BLTLocalDecoder): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) - if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) - elif isinstance(module, BLTGlobalTransformer): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) - elif isinstance(module, BLTPatcher): emb_std = module.config.dim ** (-0.5) module.embed_tokens._custom_std = emb_std @@ -1000,7 +970,7 @@ def __init__(self, config: BLTConfig): # Model components self.local_encoder = BLTLocalEncoder(config.encoder_config) - self.global_transformer = BLTGlobalTransformer(config) + self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) # Hash embeddings @@ -1028,7 +998,6 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = _, patch_lengths, _ = self.patcher( tokens, patch_size=self.patch_size, - include_next_token=True, threshold=self.patching_threshold, max_patch_length=self.max_patch_length, patching_batch_size=self.patching_batch_size, @@ -1132,7 +1101,6 @@ def forward( self, token_values: torch.Tensor, patch_size: Optional[int] = None, - include_next_token: bool = True, threshold: Optional[float] = None, max_patch_length: Optional[int] = None, patching_batch_size: int = 1, @@ -1181,98 +1149,78 @@ def forward( # Always compute patch lengths from concatenated entropies batch_size, sequence_length = token_values.shape - seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length # Find patch start IDs based on entropy if patch_size is not None: patch_lengths = self.patch_lengths_from_entropies( - concat_entropies, - seq_len_next_tok, + entropies=concat_entropies, + sequence_length=sequence_length, patch_size=patch_size, threshold=threshold, - include_next_token=include_next_token ) - else: # Default to byte-level patching - patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) - + patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) return concat_entropies, patch_lengths, concat_predictions - @staticmethod - def patch_start_ids_from_patch_start_mask(patch_start_mask: torch.BoolTensor) -> torch.LongTensor: + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): """ - Convert a binary patch start mask into a tensor of patch start indices. + Computes patch lengths from token entropies. - Each row in the output contains the indices where patches start, - padded with the sequence length if fewer patches were found. + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). """ - batch_size, seq_len = patch_start_mask.shape - max_patches = patch_start_mask.sum(dim=1).max().item() - - # Fill with seq_len as padding - patch_start_ids = torch.full( - (batch_size, max_patches), - fill_value=seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - for i in range(batch_size): - ids = torch.nonzero(patch_start_mask[i], as_tuple=False).flatten() - patch_start_ids[i, :ids.numel()] = ids + batch_size = entropies.shape[0] - return patch_start_ids + # Always include token 0 and 1 as starting tokens + init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + offset = init_tokens.shape[1] - - @staticmethod - def patch_lengths_from_entropies( - entropies: torch.Tensor, - seq_len: int, - patch_size: Optional[int] = None, - threshold: Optional[float] = None, - include_next_token: bool = True, - ) -> torch.Tensor: - """ - Compute patch lengths directly from entropies. - - Args: - entropies (Tensor): [batch_size, sequence_length] - seq_len (int): sequence length including next token if used. - patch_size (int, optional): Number of patches to extract if threshold is not given. - threshold (float, optional): Entropy threshold for dynamic patching. - include_next_token (bool): Whether to account for next token in patch span. - - Returns: - patch_lengths (Tensor): [batch_size, num_patches] of patch lengths - """ - batch_size, sequence_length = entropies.shape[:2] - - # Always keep first patch starting at 0 and 1 - first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - entropies = entropies[:, 1:] # Skip first token for entropy-based selection - offset = first_ids.shape[1] + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] if threshold is None: - assert patch_size is not None, "patch_size must be specified when threshold is None" + # Use top-k entropy values to define patch start points num_patches = sequence_length // patch_size - patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices - patch_start_ids = patch_start_ids.sort(dim=1).values + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values else: - patch_start_mask = entropies > threshold - if not include_next_token: - patch_start_mask = patch_start_mask[:, :-1] - patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - - # Final start ids (prepend 0 and 1) - patch_start_ids = torch.cat((first_ids, patch_start_ids + offset), dim=1) - - # Compute patch lengths - last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) - patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) - patch_lengths = patch_end_ids - patch_start_ids + 1 + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + return patch_lengths From f686d0b389ac06bbc4c5b7cb6b115de07229e7b9 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 25 Jun 2025 14:55:50 +0000 Subject: [PATCH 040/139] clean up patcher helpers further --- .../models/blt_wip/modeling_blt.py | 226 +++++++----------- 1 file changed, 82 insertions(+), 144 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 72ac16522847..3854ca5e1396 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -40,14 +40,11 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask - logger = logging.get_logger(__name__) -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -59,7 +56,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP class BLTMLP(nn.Module): def __init__(self, config): super().__init__() @@ -75,7 +71,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj -# Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -135,7 +130,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText class BLTRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -235,7 +229,6 @@ def forward( return outputs -# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT class BLTSelfAttention(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() @@ -316,59 +309,66 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - - - -def check_non_zero_after_zero(tensor): - zero_mask = tensor == 0 - shifted_mask = torch.cat( - [ - torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), - zero_mask[:, :-1], - ], - dim=1, - ) - non_zero_after_zero = (tensor != 0) & shifted_mask - return non_zero_after_zero.any() + def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, + 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, + 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, ] prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime ** powers return torch.sum(token_tensor * prime_powers, dim=-1) def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): - """ - Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. - - expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. - returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. - - Note: max hash can make a big difference on the number of collisions. - """ + """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): batch_size, seq_len = token_ids.shape - prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - token_ids = torch.cat([prefix, token_ids], dim=1) - windows = token_ids.unfold(1, group_size, 1) - # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values_range = hashes % max_hash - hash_values_range.requires_grad = False - return hash_values_range + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) + for _ in range(num_embeddings) + ] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings def _prepare_patch_cross_attention_mask( @@ -455,34 +455,47 @@ def _prepare_patch_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: - #TODO: refactor to be more readable + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ if max_patch_length is None: return patch_lengths batch_size = patch_lengths.size(0) - split_all = [] - max_len = 0 + processed = [] for seq in patch_lengths: splits = [] for length in seq[seq > 0]: - # Split long patches into max_patch_length chunks - full, rem = divmod(length.item(), max_patch_length) - splits.extend([max_patch_length] * full + ([rem] if rem else [])) - split_all.append(splits) - max_len = max(max_len, len(splits)) - - # Pad sequences to the maximum length + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - for i, splits in enumerate(split_all): + + for i, splits in enumerate(processed): if splits: padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - # Trim trailing columns that are all zeros - last_non_zero = (padded != 0).flip(1).int().argmax(1).min() - if last_non_zero < padded.shape[1]: - padded = padded[:, :padded.shape[1] - last_non_zero] + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] return padded @@ -627,12 +640,11 @@ class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTLocalDecoderConfig): super().__init__() - # Extract config values to instance attributes self.hidden_size = config.hidden_size self.vocab_size=config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_decoder = True self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k @@ -655,11 +667,7 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) ) - self.lm_head = nn.Linear( - self.hidden_size, - self.vocab_size, - bias=False, - ) + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False,) def forward( @@ -851,48 +859,6 @@ def forward( return hidden_states, cache -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """ - Compute embeddings using hash token embeddings. - - Args: - local_encoder_tokens: Input tokens tensor - local_encoder: Encoder object with embed_tokens method - encoder_hash_tok_embedding: ModuleList of hash token embeddings - encoder_hash_byte_group_nb_functions: Number of hash functions - encoder_hash_byte_group_size: List of byte group sizes - encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings - - Returns: - torch.Tensor: Combined embeddings - """ - if encoder_hash_tok_embedding is None: - return None - - local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens) - - i = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for byte_group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, - byte_group_size, - hash_func_nb=func_nb, - max_hash=encoder_hash_byte_group_vocab, - ) - hash_tok_embedding = encoder_hash_tok_embedding[i] - local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) - i += 1 - - assert i == len(encoder_hash_tok_embedding) - return local_encoder_embeds class BLTPreTrainedModel(PreTrainedModel): @@ -982,7 +948,7 @@ def __init__(self, config: BLTConfig): # Patcher initialization if self.patch_in_forward: - self.patcher = BLTPatcher(config) + self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() for param in self.patcher.parameters(): param.requires_grad = False @@ -1076,8 +1042,8 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config): - super().__init__(config.patcher_config) + def __init__(self, config: BLTPatcherConfig): + super().__init__(config) self.rotary_emb = BLTRotaryEmbedding(config=self.config) @@ -1223,34 +1189,6 @@ def patch_lengths_from_entropies( return patch_lengths - -def init_hash_embeddings( - config, - local_encoder_dim: int, - encoder_hash_byte_group_size: list, -): - """Initialize hash-based token embeddings for the BLT encoder.""" - if config.encoder_hash_byte_group_size is None: - return None - - embeddings = [] - emb_dim = local_encoder_dim - encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - - for _ in range(config.encoder_hash_byte_group_nb_functions): - for _ in encoder_hash_byte_group_size: - embeddings.append( - nn.Embedding( - encoder_hash_byte_group_vocab, - emb_dim, - ) - ) - - return nn.ModuleList(embeddings) - - - - __all__ = [ "BLTPreTrainedModel", "BLTModel", From cfde8979e9f27abd06142798bd5c1fac6d6d1efc Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 11:21:47 +0000 Subject: [PATCH 041/139] clean up --- .../models/blt_wip/configuration_blt.py | 137 ++---------------- .../models/blt_wip/modeling_blt.py | 99 +++++-------- 2 files changed, 51 insertions(+), 185 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index fed44222add6..3170291fa607 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -213,8 +213,6 @@ class BLTPatcherConfig(PretrainedConfig): Layer normalization epsilon for the entropy model. dropout (`float`, *optional*, defaults to 0.0): Dropout probability for the entropy model. - sliding_window (`int`, *optional*): - Sliding window size for the entropy model attention. ffn_dim_multiplier (`float`, *optional*): Feedforward dimension multiplier for the entropy model. multiple_of (`int`, *optional*, defaults to 256): @@ -227,12 +225,6 @@ class BLTPatcherConfig(PretrainedConfig): Attention implementation for the entropy model. attn_bias_type (`str`, *optional*, defaults to "causal"): Attention bias type for the entropy model. - init_base_std (`float`, *optional*): - Base initialization standard deviation for the entropy model. - init_std_factor (`str`, *optional*, defaults to "disabled"): - Initialization std factor for the entropy model. - dim_token_emb (`int`, *optional*): - Token embedding dimension for the entropy model. weight_tying (`bool`, *optional*, defaults to False): Whether to tie embeddings in the entropy model. bos_token_id (`int`, *optional*, defaults to 1): @@ -254,16 +246,12 @@ def __init__( max_seqlen=1024, norm_eps=1e-5, dropout=0.0, - sliding_window=None, ffn_dim_multiplier=None, multiple_of=256, rope_theta=10000.0, rope_use_fp32_in_outer_product=False, attn_impl="sdpa", attn_bias_type="causal", - init_base_std=None, - init_std_factor="disabled", - dim_token_emb=None, weight_tying=False, bos_token_id=1, eos_token_id=2, @@ -278,16 +266,12 @@ def __init__( self.max_seqlen = max_seqlen self.norm_eps = norm_eps self.dropout = dropout - self.sliding_window = sliding_window self.ffn_dim_multiplier = ffn_dim_multiplier self.multiple_of = multiple_of self.rope_theta = rope_theta self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product self.attn_impl = attn_impl self.attn_bias_type = attn_bias_type - self.init_base_std = init_base_std - self.init_std_factor = InitStdFactor(init_std_factor) - self.dim_token_emb = dim_token_emb self.weight_tying = weight_tying super().__init__( @@ -312,7 +296,7 @@ def __init__( class BLTConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a + This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a BLT model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the @@ -379,23 +363,9 @@ class BLTConfig(PretrainedConfig): Attention implementation to use ("sdpa" or "flex_attention"). attn_bias_type (`str`, *optional*, defaults to "causal"): Type of attention bias to apply. - local_attention_window_len (`int`, *optional*): - Window length for local attention. use_rope (`bool`, *optional*, defaults to True): Whether to use rotary position embeddings. - # Initialization - init_base_std (`float`, *optional*): - Base standard deviation for weight initialization. - init_std_factor (`str`, *optional*, defaults to "disabled"): - Factor for adjusting initialization standard deviation. - - # Embedding dimensions - dim_token_emb (`int`, *optional*): - Token embedding dimension. - dim_token (`int`, *optional*): - Token dimension. - # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): Whether to perform patching during forward pass. @@ -437,8 +407,6 @@ class BLTConfig(PretrainedConfig): Whether to apply cross attention to all decoder layers. cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): Whether to apply cross attention to all encoder layers. - cross_attn_use_flex_attention (`bool`, *optional*, defaults to True): - Whether to use flexible attention for cross attention. cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): Whether to initialize cross attention by pooling. @@ -453,12 +421,6 @@ class BLTConfig(PretrainedConfig): Vocabulary size for hash byte groups. encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): Number of hash functions for byte groups. - encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False): - Whether to enable byte n-grams in encoder. - encoder_ngram_to_size_str (`str`, *optional*): - String defining n-gram sizes. - downsampling_by_pooling (`str`, *optional*): - Type of pooling for downsampling. # Model behavior share_encoder_decoder_emb (`bool`, *optional*, defaults to True): @@ -466,16 +428,6 @@ class BLTConfig(PretrainedConfig): weight_tying (`bool`, *optional*, defaults to False): Whether to tie input and output embeddings. - # Performance optimization - sequence_parallel (`bool`, *optional*, defaults to False): - Whether to use sequence parallelism. - loss_parallel (`bool`, *optional*, defaults to False): - Whether to use loss parallelism. - fuse_sequence_parallel (`bool`, *optional*, defaults to False): - Whether to fuse sequence parallel operations. - use_fsdp (`bool`, *optional*, defaults to True): - Whether to use fully sharded data parallel. - # Parameter mixing pm_size (`int`, *optional*, defaults to 0): Parameter mixing size. @@ -543,14 +495,7 @@ def __init__( attn_impl="sdpa", _attn_implementation="sdpa", attn_bias_type="causal", - local_attention_window_len=None, use_rope=True, - # Initialization - init_base_std=None, - init_std_factor="disabled", - # Embedding dimensions - dim_token_emb=None, - dim_token=None, # Patching configuration patch_in_forward=False, realtime_patching=True, @@ -572,7 +517,6 @@ def __init__( cross_attn_nheads=16, cross_attn_all_layers_decoder=False, cross_attn_all_layers_encoder=False, - cross_attn_use_flex_attention=True, cross_attn_init_by_pooling=False, # Encoder configurations use_local_encoder_transformer=False, @@ -580,17 +524,9 @@ def __init__( encoder_hash_byte_group_size=None, encoder_hash_byte_group_vocab=30000, encoder_hash_byte_group_nb_functions=3, - encoder_enable_byte_ngrams=False, - encoder_ngram_to_size_str=None, - downsampling_by_pooling=None, # Model behavior share_encoder_decoder_emb=True, weight_tying=False, - # Performance optimization - sequence_parallel=False, - loss_parallel=False, - fuse_sequence_parallel=False, - use_fsdp=True, # Parameter mixing pm_size=0, # Special tokens @@ -604,7 +540,6 @@ def __init__( **kwargs, ): - self.sliding_window = None # Basic model configuration self.vocab_size = vocab_size self.max_seqlen = max_seqlen @@ -643,17 +578,8 @@ def __init__( self.attn_impl = attn_impl self._attn_implementation = _attn_implementation self.attn_bias_type = attn_bias_type - self.local_attention_window_len = local_attention_window_len self.use_rope = use_rope - # Initialization - self.init_base_std = init_base_std - self.init_std_factor = InitStdFactor(init_std_factor) - - # Embedding dimensions - self.dim_token_emb = dim_token_emb - self.dim_token = dim_token - # Patching configuration self.patch_in_forward = patch_in_forward self.realtime_patching = realtime_patching @@ -676,7 +602,6 @@ def __init__( self.cross_attn_nheads = cross_attn_nheads self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder - self.cross_attn_use_flex_attention = cross_attn_use_flex_attention self.cross_attn_init_by_pooling = cross_attn_init_by_pooling # Encoder configurations @@ -685,20 +610,11 @@ def __init__( self.encoder_hash_byte_group_size = encoder_hash_byte_group_size self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions - self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams - self.encoder_ngram_to_size_str = encoder_ngram_to_size_str - self.downsampling_by_pooling = downsampling_by_pooling # Model behavior self.share_encoder_decoder_emb = share_encoder_decoder_emb self.weight_tying = weight_tying - # Performance optimization - self.sequence_parallel = sequence_parallel - self.loss_parallel = loss_parallel - self.fuse_sequence_parallel = fuse_sequence_parallel - self.use_fsdp = use_fsdp - # Parameter mixing self.pm_size = pm_size @@ -720,7 +636,7 @@ def __init__( dropout=dropout, max_position_embeddings=max_encoder_seq_length or max_seqlen, rope_theta=rope_theta, - rope_scaling={"type": "default", "rope_type": "default"}, + rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, multiple_of=multiple_of, ) @@ -738,12 +654,11 @@ def __init__( dropout=dropout, max_position_embeddings=max_encoder_seq_length or max_seqlen, rope_theta=rope_theta, - rope_scaling={"type": "default", "rope_type": "default"}, + rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, multiple_of=multiple_of, ) - self.global_config = BLTGlobalTransformerConfig( hidden_size=dim_global, num_attention_heads=n_heads_global, @@ -753,17 +668,15 @@ def __init__( dropout=dropout, max_position_embeddings=max_seqlen, rope_theta=rope_theta, - rope_scaling={"type": "default", "rope_type": "default"}, + rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, multiple_of=multiple_of, - global_dim_patch_emb=self.global_dim_patch_emb, + global_dim_patch_emb=dim_local_encoder * cross_attn_k, ) - # Initialize patcher configuration if patcher_args is not None: self.patcher_config = BLTPatcherConfig(**patcher_args) else: - # Use default values if no patcher_args provided self.patcher_config = BLTPatcherConfig() # Handle hash byte group size validation @@ -772,16 +685,14 @@ def __init__( int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 ] - # Rope - self.rope_scaling={ - "type": "default", - "rope_type": "default" - } + # Rope scaling configuration + self.rope_scaling = {"rope_type": "default"} - self.num_key_value_heads=n_heads_local_encoder - self.max_position_embeddings=max_seqlen - self.hidden_size=dim_local_encoder - self.num_attention_heads=n_heads_local_encoder + # Set compatibility attributes for transformers + self.num_key_value_heads = n_heads_local_encoder + self.max_position_embeddings = max_seqlen + self.hidden_size = dim_local_encoder + self.num_attention_heads = n_heads_local_encoder # Calculate intermediate_size using BLTMLP logic for each component # Note: Each component uses its own hidden dimension, not the main dim @@ -794,29 +705,6 @@ def __init__( **kwargs, ) - @property - def global_dim_patch_emb(self): - return self.dim_local_encoder * self.cross_attn_k - - - def get_init_std_factor(self, depth: int) -> float: - """ - Calculate the initialization standard deviation scaling factor for a given layer depth. - - Args: - depth: Current layer depth (0-indexed) - - Returns: - Scaling factor to divide the base initialization std by - """ - if self.init_std_factor == InitStdFactor.CURRENT_DEPTH: - return (2 * (depth + 1)) ** 0.5 - else: # DISABLED - return 1.0 - - - - __all__ = [ "BLTConfig", "BLTPatcherConfig", @@ -826,4 +714,3 @@ def get_init_std_factor(self, depth: int) -> float: "InitStdFactor", "PatchingModeEnum" ] - diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 3854ca5e1396..ff019f572dfb 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -42,6 +42,7 @@ from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) @@ -56,20 +57,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class BLTMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj def eager_attention_forward( module: nn.Module, @@ -97,7 +84,6 @@ def eager_attention_forward( return attn_output, attn_weights - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # TODO: not exactly equivalent to other transformers implementations,, need feedback # Extract first head_dim//2 elements which correspond to the unique frequencies @@ -130,6 +116,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) + +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class BLTRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -379,8 +382,6 @@ def _prepare_patch_cross_attention_mask( cross_attn_k: int = 1, dtype: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: - #TODO: refactor to be more readable - """ Prepare cross-attention mask for patch-based attention, following mllama's robust approach. @@ -640,11 +641,12 @@ class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTLocalDecoderConfig): super().__init__() + # Extract config values to instance attributes self.hidden_size = config.hidden_size self.vocab_size=config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = True + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k @@ -667,7 +669,11 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) ) - self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False,) + self.lm_head = nn.Linear( + self.hidden_size, + self.vocab_size, + bias=False, + ) def forward( @@ -838,18 +844,17 @@ def __init__(self, config: BLTGlobalTransformerConfig): def forward( self, - input_ids: torch.Tensor, - input_embeds: Optional[torch.Tensor] = None, + input_embeds: torch.Tensor, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, _, _ = input_embeds.shape + batch_size, seq_len, _ = input_embeds.shape hidden_states = input_embeds hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): @@ -918,36 +923,19 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - # Core configuration - self.patch_in_forward = config.patch_in_forward - self.patching_mode = config.patching_mode - self.patch_size = config.patch_size - self.patching_threshold = config.patching_threshold - self.max_patch_length = config.max_patch_length - self.patching_batch_size = config.patching_batch_size - self.patching_device = config.patching_device - - # Cross attention configuration (always enabled) - self.cross_attn_k = config.cross_attn_k - - # Token IDs - self.boe_id = config.boe_id - self.eos_token_id = config.eos_token_id + self.config = config - # Model components self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - # Hash embeddings self.encoder_hash_tok_embedding = init_hash_embeddings( config, local_encoder_dim=config.dim_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - # Patcher initialization - if self.patch_in_forward: + if self.config.patch_in_forward: self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() for param in self.patcher.parameters(): @@ -960,29 +948,27 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = # Handle patching if patch_lengths is None: - if self.patching_mode == PatchingModeEnum.entropy: + if self.config.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( tokens, - patch_size=self.patch_size, - threshold=self.patching_threshold, - max_patch_length=self.max_patch_length, - patching_batch_size=self.patching_batch_size, - device=self.patching_device, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, ) else: # Default to byte-level patching patch_lengths = process_patch_lengths( torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), - self.max_patch_length + self.config.max_patch_length ) - # Generate patch IDs and prepare cross-attention masks patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids, patch_lengths.shape[1], sequence_length, True, self.cross_attn_k, torch.float32 + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 ) - # Compute embeddings with hashing encoder_embeds = compute_hash_embeddings( tokens, self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, @@ -990,7 +976,6 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = self.config.encoder_hash_byte_group_vocab, ) - # Local encoder forward pass encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=tokens, input_embeds=encoder_embeds, @@ -1001,21 +986,15 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = patch_ids=patch_ids, ) - # Global transformer forward pass global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_tokens = tokens.new_full((batch_size, global_hidden_states.shape[1]), self.boe_id) - eos_positions = torch.where(tokens == self.eos_token_id) - global_tokens[eos_positions[0], patch_ids[eos_positions]] = self.eos_token_id global_hidden_states, _ = self.global_transformer( - input_ids=global_tokens, input_embeds=global_hidden_states, ) - # Local decoder forward pass decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.cross_attn_k, torch.float32 + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 ) output, _ = self.local_decoder( From 09e574b8190b1453be391d26fb90a78ebc924542 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 12:26:42 +0000 Subject: [PATCH 042/139] some config renaming --- .../models/blt_wip/configuration_blt.py | 55 ++++++++++--------- .../models/blt_wip/modeling_blt_modular.py | 44 ++++++++++----- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 3170291fa607..657589200a9d 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -258,7 +258,7 @@ def __init__( **kwargs, ): self.vocab_size = vocab_size - self.dim = dim + self.hidden_size = dim self.n_layers = n_layers self.n_heads = n_heads self.head_dim = head_dim if head_dim is not None else (dim // n_heads) @@ -288,7 +288,7 @@ def __init__( self.hidden_act = "silu" # BLT uses silu activation # Calculate intermediate_size using BLTMLP logic based on actual hidden_size - self.intermediate_size = multiple_of * ((int(8 * dim / 3) + multiple_of - 1) // multiple_of) + self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of) # Set simple rope scaling for patcher (no complex dynamic rope) self.rope_scaling = {"rope_type": "default"} @@ -305,20 +305,20 @@ class BLTConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 256): Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. - max_seqlen (`int`, *optional*, defaults to 1024): + max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. # Main architecture dimensions - dim (`int`, *optional*, defaults to 512): + hidden_size (`int`, *optional*, defaults to 512): Main dimension of the model. - n_layers (`int`, *optional*, defaults to 8): + num_hidden_layers (`int`, *optional*, defaults to 8): Number of layers in the main transformer. - n_heads (`int`, *optional*, defaults to 8): + num_attention_heads (`int`, *optional*, defaults to 8): Number of attention heads in the main transformer. head_dim (`int`, *optional*): - Dimension of each attention head. If not specified, computed as dim // n_heads. - n_kv_heads (`int`, *optional*): - Number of key-value heads for grouped query attention. If not specified, defaults to n_heads. + Dimension of each attention head. If not specified, computed as hidden_size // num_attention_heads. + num_key_value_heads (`int`, *optional*): + Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads. # Component-specific dimensions dim_global (`int`, *optional*, defaults to 512): @@ -464,13 +464,13 @@ class BLTConfig(PretrainedConfig): def __init__( self, vocab_size=256, - max_seqlen=1024, + max_position_embeddings=1024, # Main architecture dimensions - dim=512, - n_layers=8, - n_heads=8, + hidden_size=512, + num_hidden_layers=8, + num_attention_heads=8, head_dim=None, - n_kv_heads=None, + num_key_value_heads=None, # Component-specific dimensions dim_global=512, dim_local_decoder=512, @@ -542,14 +542,14 @@ def __init__( # Basic model configuration self.vocab_size = vocab_size - self.max_seqlen = max_seqlen + self.max_position_embeddings = max_position_embeddings # Main architecture dimensions - self.dim = dim - self.n_layers = n_layers - self.n_heads = n_heads - self.head_dim = head_dim if head_dim is not None else (dim // n_heads) - self.n_kv_heads = n_kv_heads + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads) + self.num_key_value_heads = num_key_value_heads # Component-specific dimensions self.dim_global = dim_global @@ -630,11 +630,11 @@ def __init__( pm_size=pm_size, hidden_size=dim_local_encoder, num_attention_heads=n_heads_local_encoder, - num_key_value_heads=n_kv_heads, + num_key_value_heads=num_key_value_heads, num_hidden_layers=n_layers_local_encoder, norm_eps=norm_eps, dropout=dropout, - max_position_embeddings=max_encoder_seq_length or max_seqlen, + max_position_embeddings=max_encoder_seq_length or max_position_embeddings, rope_theta=rope_theta, rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, @@ -648,11 +648,11 @@ def __init__( dim_global=dim_global, hidden_size=dim_local_decoder, num_attention_heads=n_heads_local_decoder, - num_key_value_heads=n_kv_heads, + num_key_value_heads=num_key_value_heads, num_hidden_layers=n_layers_local_decoder, norm_eps=norm_eps, dropout=dropout, - max_position_embeddings=max_encoder_seq_length or max_seqlen, + max_position_embeddings=max_encoder_seq_length or max_position_embeddings, rope_theta=rope_theta, rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, @@ -666,7 +666,7 @@ def __init__( num_hidden_layers=n_layers_global, norm_eps=norm_eps, dropout=dropout, - max_position_embeddings=max_seqlen, + max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, @@ -690,7 +690,7 @@ def __init__( # Set compatibility attributes for transformers self.num_key_value_heads = n_heads_local_encoder - self.max_position_embeddings = max_seqlen + self.max_position_embeddings = max_position_embeddings self.hidden_size = dim_local_encoder self.num_attention_heads = n_heads_local_encoder @@ -705,6 +705,8 @@ def __init__( **kwargs, ) + + __all__ = [ "BLTConfig", "BLTPatcherConfig", @@ -714,3 +716,4 @@ def __init__( "InitStdFactor", "PatchingModeEnum" ] + diff --git a/src/transformers/models/blt_wip/modeling_blt_modular.py b/src/transformers/models/blt_wip/modeling_blt_modular.py index 217cd809af80..0dd0b76d840c 100644 --- a/src/transformers/models/blt_wip/modeling_blt_modular.py +++ b/src/transformers/models/blt_wip/modeling_blt_modular.py @@ -228,33 +228,47 @@ def _prepare_patch_cross_attention_mask( return cross_attention_mask, full_text_row_masked_out_mask -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ if max_patch_length is None: return patch_lengths batch_size = patch_lengths.size(0) - split_all = [] - max_len = 0 + processed = [] for seq in patch_lengths: splits = [] for length in seq[seq > 0]: - # Split long patches into max_patch_length chunks - full, rem = divmod(length.item(), max_patch_length) - splits.extend([max_patch_length] * full + ([rem] if rem else [])) - split_all.append(splits) - max_len = max(max_len, len(splits)) - - # Pad sequences to the maximum length + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - for i, splits in enumerate(split_all): + + for i, splits in enumerate(processed): if splits: padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - # Trim trailing columns that are all zeros - last_non_zero = (padded != 0).flip(1).int().argmax(1).min() - if last_non_zero < padded.shape[1]: - padded = padded[:, :padded.shape[1] - last_non_zero] + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] return padded From f649ff34b8ff71f6450dc5f1a897c7698b0070f1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 12:49:36 +0000 Subject: [PATCH 043/139] clean up unused configs --- .../models/blt_wip/configuration_blt.py | 96 ++++--------------- 1 file changed, 18 insertions(+), 78 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 657589200a9d..f6a51971afd8 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -70,7 +70,7 @@ def __init__( self.vocab_size = vocab_size self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k - self.dim_global=dim_global + self.hidden_size_global=dim_global self.pm_size=pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads @@ -124,7 +124,7 @@ def __init__( self.vocab_size = vocab_size self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k - self.dim_global=dim_global + self.hidden_size_global=dim_global self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -259,14 +259,14 @@ def __init__( ): self.vocab_size = vocab_size self.hidden_size = dim - self.n_layers = n_layers - self.n_heads = n_heads + self.num_hidden_layers = n_layers + self.num_attention_heads = n_heads self.head_dim = head_dim if head_dim is not None else (dim // n_heads) - self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads + self.num_key_value_heads = n_kv_heads if n_kv_heads is not None else n_heads self.max_seqlen = max_seqlen self.norm_eps = norm_eps self.dropout = dropout - self.ffn_dim_multiplier = ffn_dim_multiplier + self.intermediate_size = ffn_dim_multiplier self.multiple_of = multiple_of self.rope_theta = rope_theta self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product @@ -283,7 +283,7 @@ def __init__( # Add attributes needed for compatibility with transformer models self.hidden_size = dim self.num_attention_heads = n_heads - self.num_key_value_heads = self.n_kv_heads # Use the computed n_kv_heads + self.num_key_value_heads = self.num_key_value_heads # Use the computed n_kv_heads self.max_position_embeddings = max_seqlen self.hidden_act = "silu" # BLT uses silu activation @@ -355,64 +355,36 @@ class BLTConfig(PretrainedConfig): # Positional encoding rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. - rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): - Whether to use fp32 in RoPE outer product computation. # Attention configuration attn_impl (`str`, *optional*, defaults to "sdpa"): Attention implementation to use ("sdpa" or "flex_attention"). - attn_bias_type (`str`, *optional*, defaults to "causal"): - Type of attention bias to apply. - use_rope (`bool`, *optional*, defaults to True): - Whether to use rotary position embeddings. # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): Whether to perform patching during forward pass. - realtime_patching (`bool`, *optional*, defaults to True): - Whether to use realtime patching. patch_size (`float`, *optional*): Size of patches for static patching. patching_mode (`str`, *optional*): Mode for patching ("entropy", "static", etc.). patching_threshold (`float`, *optional*): Threshold for entropy-based patching. - patching_threshold_add (`float`, *optional*): - Additional threshold parameter for patching. - monotonicity (`bool`, *optional*, defaults to False): - Whether to enforce monotonicity in patching. patching_batch_size (`int`, *optional*, defaults to 1): Batch size for patching operations. patching_device (`str`, *optional*, defaults to "cuda"): Device to use for patching operations. max_patch_length (`int`, *optional*): Maximum length of patches. - entropy_model_checkpoint_dir (`str`, *optional*): - Directory containing entropy model checkpoint. # Cross attention configurations - cross_attn_encoder (`bool`, *optional*, defaults to False): - Whether to use cross attention in encoder. - cross_attn_decoder (`bool`, *optional*, defaults to False): - Whether to use cross attention in decoder. - cross_attn_window_encoder (`int`, *optional*): - Cross attention window for encoder. - cross_attn_window_decoder (`int`, *optional*): - Cross attention window for decoder. cross_attn_k (`int`, *optional*): Number of cross attention components. - cross_attn_nheads (`int`, *optional*): - Number of heads for cross attention. cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False): Whether to apply cross attention to all decoder layers. cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): Whether to apply cross attention to all encoder layers. - cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): - Whether to initialize cross attention by pooling. # Encoder configurations - use_local_encoder_transformer (`bool`, *optional*, defaults to False): - Whether to use transformer in local encoder. max_encoder_seq_length (`int`, *optional*): Maximum sequence length for encoder. encoder_hash_byte_group_size (`Any`, *optional*): @@ -423,8 +395,6 @@ class BLTConfig(PretrainedConfig): Number of hash functions for byte groups. # Model behavior - share_encoder_decoder_emb (`bool`, *optional*, defaults to True): - Whether to share encoder and decoder embeddings. weight_tying (`bool`, *optional*, defaults to False): Whether to tie input and output embeddings. @@ -490,42 +460,27 @@ def __init__( hidden_act="silu", # Positional encoding rope_theta=10000.0, - rope_use_fp32_in_outer_product=False, # Attention configuration attn_impl="sdpa", _attn_implementation="sdpa", - attn_bias_type="causal", - use_rope=True, # Patching configuration patch_in_forward=False, - realtime_patching=True, patch_size=None, patching_mode=None, patching_threshold=None, - patching_threshold_add=None, - monotonicity=False, patching_batch_size=1, patching_device="cuda", max_patch_length=None, - entropy_model_checkpoint_dir=None, # Cross attention configurations - cross_attn_encoder=False, - cross_attn_decoder=False, - cross_attn_window_encoder=None, - cross_attn_window_decoder=None, cross_attn_k=2, - cross_attn_nheads=16, cross_attn_all_layers_decoder=False, cross_attn_all_layers_encoder=False, - cross_attn_init_by_pooling=False, # Encoder configurations - use_local_encoder_transformer=False, max_encoder_seq_length=None, encoder_hash_byte_group_size=None, encoder_hash_byte_group_vocab=30000, encoder_hash_byte_group_nb_functions=3, # Model behavior - share_encoder_decoder_emb=True, weight_tying=False, # Parameter mixing pm_size=0, @@ -552,67 +507,52 @@ def __init__( self.num_key_value_heads = num_key_value_heads # Component-specific dimensions - self.dim_global = dim_global - self.dim_local_decoder = dim_local_decoder - self.dim_local_encoder = dim_local_encoder - self.n_layers_global = n_layers_global - self.n_layers_local_decoder = n_layers_local_decoder - self.n_layers_local_encoder = n_layers_local_encoder - self.n_heads_global = n_heads_global - self.n_heads_local_decoder = n_heads_local_decoder - self.n_heads_local_encoder = n_heads_local_encoder - self.n_kv_heads_global = n_kv_heads_global + self.hidden_size_global = dim_global + self.hidden_size_local_decoder = dim_local_decoder + self.hidden_size_local_encoder = dim_local_encoder + self.num_hidden_layers_global = n_layers_global + self.num_hidden_layers_local_decoder = n_layers_local_decoder + self.num_hidden_layers_local_encoder = n_layers_local_encoder + self.num_attention_heads_global = n_heads_global + self.num_attention_heads_local_decoder = n_heads_local_decoder + self.num_attention_heads_local_encoder = n_heads_local_encoder + self.num_key_value_heads_global = n_kv_heads_global # Transformer configuration self.norm_eps = norm_eps self.dropout = dropout - self.ffn_dim_multiplier = ffn_dim_multiplier + self.intermediate_size = ffn_dim_multiplier self.multiple_of = multiple_of self.hidden_act = hidden_act # Positional encoding self.rope_theta = rope_theta - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product # Attention configuration self.attn_impl = attn_impl self._attn_implementation = _attn_implementation - self.attn_bias_type = attn_bias_type - self.use_rope = use_rope # Patching configuration self.patch_in_forward = patch_in_forward - self.realtime_patching = realtime_patching self.patch_size = patch_size self.patching_mode = patching_mode self.patching_threshold = patching_threshold - self.patching_threshold_add = patching_threshold_add - self.monotonicity = monotonicity self.patching_batch_size = patching_batch_size self.patching_device = patching_device self.max_patch_length = max_patch_length - self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir # Cross attention configurations - self.cross_attn_encoder = cross_attn_encoder - self.cross_attn_decoder = cross_attn_decoder - self.cross_attn_window_encoder = cross_attn_window_encoder - self.cross_attn_window_decoder = cross_attn_window_decoder self.cross_attn_k = cross_attn_k - self.cross_attn_nheads = cross_attn_nheads self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = cross_attn_init_by_pooling # Encoder configurations - self.use_local_encoder_transformer = use_local_encoder_transformer self.max_encoder_seq_length = max_encoder_seq_length self.encoder_hash_byte_group_size = encoder_hash_byte_group_size self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions # Model behavior - self.share_encoder_decoder_emb = share_encoder_decoder_emb self.weight_tying = weight_tying # Parameter mixing From c1a55088de37fff9dd340816036d4b005fe5d09a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 14:36:45 +0000 Subject: [PATCH 044/139] clean up configs --- .../models/blt_wip/configuration_blt.py | 189 +++++++----------- .../models/blt_wip/modeling_blt.py | 24 +-- 2 files changed, 85 insertions(+), 128 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index f6a51971afd8..9e3f4e30341c 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -49,7 +49,7 @@ def __init__( vocab_size=256, cross_attn_all_layers=True, cross_attn_k=2, - dim_global=2048, + hidden_size_global=2048, pm_size=0, hidden_size=512, num_attention_heads=8, @@ -70,8 +70,8 @@ def __init__( self.vocab_size = vocab_size self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k - self.hidden_size_global=dim_global - self.pm_size=pm_size + self.hidden_size_global = hidden_size_global + self.pm_size = pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -87,8 +87,8 @@ def __init__( self.multiple_of = multiple_of self._attn_implementation = _attn_implementation self.decoder_dim_token_emb = 1024 - self.encoder_dim_token_emb=1024 - self.encoder_dim_patch_emb=self.hidden_size + self.encoder_dim_token_emb = 1024 + self.encoder_dim_patch_emb = self.hidden_size super().__init__(**kwargs) @@ -104,7 +104,7 @@ def __init__( vocab_size=256, cross_attn_all_layers=True, cross_attn_k=2, - dim_global=2048, + hidden_size_global=2048, hidden_size=512, num_attention_heads=8, num_key_value_heads=None, @@ -124,7 +124,7 @@ def __init__( self.vocab_size = vocab_size self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k - self.hidden_size_global=dim_global + self.hidden_size_global = hidden_size_global self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -139,8 +139,8 @@ def __init__( self.hidden_act = hidden_act self.multiple_of = multiple_of self._attn_implementation = _attn_implementation - self.decoder_dim_token_emb=1024 - self.encoder_dim_token_emb=1024 + self.decoder_dim_token_emb = 1024 + self.encoder_dim_token_emb = 1024 super().__init__(**kwargs) @@ -197,17 +197,17 @@ class BLTPatcherConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 256): Vocabulary size for the entropy model used in patching. - dim (`int`, *optional*, defaults to 512): + hidden_size (`int`, *optional*, defaults to 512): Hidden dimension for the entropy model. - n_layers (`int`, *optional*, defaults to 8): + num_hidden_layers (`int`, *optional*, defaults to 8): Number of layers in the entropy model. - n_heads (`int`, *optional*, defaults to 8): + num_attention_heads (`int`, *optional*, defaults to 8): Number of attention heads in the entropy model. head_dim (`int`, *optional*): Dimension of each attention head in the entropy model. - n_kv_heads (`int`, *optional*): + num_key_value_heads (`int`, *optional*): Number of key-value heads in the entropy model. - max_seqlen (`int`, *optional*, defaults to 1024): + max_position_embeddings (`int`, *optional*, defaults to 1024): Maximum sequence length for the entropy model. norm_eps (`float`, *optional*, defaults to 1e-5): Layer normalization epsilon for the entropy model. @@ -219,18 +219,10 @@ class BLTPatcherConfig(PretrainedConfig): Make feedforward dimension multiple of this for the entropy model. rope_theta (`float`, *optional*, defaults to 10000.0): RoPE theta parameter for the entropy model. - rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): - Whether to use fp32 in RoPE outer product for the entropy model. attn_impl (`str`, *optional*, defaults to "sdpa"): Attention implementation for the entropy model. attn_bias_type (`str`, *optional*, defaults to "causal"): Attention bias type for the entropy model. - weight_tying (`bool`, *optional*, defaults to False): - Whether to tie embeddings in the entropy model. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of sequence token id for the entropy model. - eos_token_id (`int`, *optional*, defaults to 2): - End of sequence token id for the entropy model. """ model_type = "blt_patcher" @@ -238,12 +230,12 @@ class BLTPatcherConfig(PretrainedConfig): def __init__( self, vocab_size=256, - dim=512, - n_layers=8, - n_heads=8, + hidden_size=512, + num_hidden_layers=8, + num_attention_heads=8, head_dim=None, - n_kv_heads=None, - max_seqlen=1024, + num_key_value_heads=None, + max_position_embeddings=1024, norm_eps=1e-5, dropout=0.0, ffn_dim_multiplier=None, @@ -252,27 +244,24 @@ def __init__( rope_use_fp32_in_outer_product=False, attn_impl="sdpa", attn_bias_type="causal", - weight_tying=False, bos_token_id=1, eos_token_id=2, **kwargs, ): self.vocab_size = vocab_size - self.hidden_size = dim - self.num_hidden_layers = n_layers - self.num_attention_heads = n_heads - self.head_dim = head_dim if head_dim is not None else (dim // n_heads) - self.num_key_value_heads = n_kv_heads if n_kv_heads is not None else n_heads - self.max_seqlen = max_seqlen + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads) + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings self.norm_eps = norm_eps self.dropout = dropout self.intermediate_size = ffn_dim_multiplier self.multiple_of = multiple_of self.rope_theta = rope_theta - self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product self.attn_impl = attn_impl self.attn_bias_type = attn_bias_type - self.weight_tying = weight_tying super().__init__( bos_token_id=bos_token_id, @@ -281,10 +270,6 @@ def __init__( ) # Add attributes needed for compatibility with transformer models - self.hidden_size = dim - self.num_attention_heads = n_heads - self.num_key_value_heads = self.num_key_value_heads # Use the computed n_kv_heads - self.max_position_embeddings = max_seqlen self.hidden_act = "silu" # BLT uses silu activation # Calculate intermediate_size using BLTMLP logic based on actual hidden_size @@ -321,25 +306,25 @@ class BLTConfig(PretrainedConfig): Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads. # Component-specific dimensions - dim_global (`int`, *optional*, defaults to 512): + hidden_size_global (`int`, *optional*, defaults to 512): Dimension of the global transformer component. - dim_local_decoder (`int`, *optional*, defaults to 512): + hidden_size_local_decoder (`int`, *optional*, defaults to 512): Dimension of the local decoder component. - dim_local_encoder (`int`, *optional*, defaults to 512): + hidden_size_local_encoder (`int`, *optional*, defaults to 512): Dimension of the local encoder component. - n_layers_global (`int`, *optional*, defaults to 8): + num_hidden_layers_global (`int`, *optional*, defaults to 8): Number of layers in the global transformer. - n_layers_local_decoder (`int`, *optional*, defaults to 8): + num_hidden_layers_local_decoder (`int`, *optional*, defaults to 8): Number of layers in the local decoder. - n_layers_local_encoder (`int`, *optional*, defaults to 8): + num_hidden_layers_local_encoder (`int`, *optional*, defaults to 8): Number of layers in the local encoder. - n_heads_global (`int`, *optional*, defaults to 8): + num_attention_heads_global (`int`, *optional*, defaults to 8): Number of attention heads in the global transformer. - n_heads_local_decoder (`int`, *optional*, defaults to 8): + num_attention_heads_local_decoder (`int`, *optional*, defaults to 8): Number of attention heads in the local decoder. - n_heads_local_encoder (`int`, *optional*, defaults to 8): + num_attention_heads_local_encoder (`int`, *optional*, defaults to 8): Number of attention heads in the local encoder. - n_kv_heads_global (`int`, *optional*): + num_key_value_heads_global (`int`, *optional*): Number of key-value heads in the global transformer. # Transformer configuration @@ -394,10 +379,6 @@ class BLTConfig(PretrainedConfig): encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): Number of hash functions for byte groups. - # Model behavior - weight_tying (`bool`, *optional*, defaults to False): - Whether to tie input and output embeddings. - # Parameter mixing pm_size (`int`, *optional*, defaults to 0): Parameter mixing size. @@ -442,16 +423,16 @@ def __init__( head_dim=None, num_key_value_heads=None, # Component-specific dimensions - dim_global=512, - dim_local_decoder=512, - dim_local_encoder=512, - n_layers_global=8, - n_layers_local_decoder=8, - n_layers_local_encoder=8, - n_heads_global=8, - n_heads_local_decoder=8, - n_heads_local_encoder=8, - n_kv_heads_global=None, + hidden_size_global=512, + hidden_size_local_decoder=512, + hidden_size_local_encoder=512, + num_hidden_layers_global=8, + num_hidden_layers_local_decoder=8, + num_hidden_layers_local_encoder=8, + num_attention_heads_global=8, + num_attention_heads_local_decoder=8, + num_attention_heads_local_encoder=8, + num_key_value_heads_global=None, # Transformer configuration norm_eps=1e-5, dropout=0.0, @@ -480,15 +461,8 @@ def __init__( encoder_hash_byte_group_size=None, encoder_hash_byte_group_vocab=30000, encoder_hash_byte_group_nb_functions=3, - # Model behavior - weight_tying=False, # Parameter mixing pm_size=0, - # Special tokens - bos_token_id=1, - eos_token_id=2, - pad_token_id=-1, - boe_id=0, # Patcher configuration patcher_args=None, # Inherited @@ -507,16 +481,16 @@ def __init__( self.num_key_value_heads = num_key_value_heads # Component-specific dimensions - self.hidden_size_global = dim_global - self.hidden_size_local_decoder = dim_local_decoder - self.hidden_size_local_encoder = dim_local_encoder - self.num_hidden_layers_global = n_layers_global - self.num_hidden_layers_local_decoder = n_layers_local_decoder - self.num_hidden_layers_local_encoder = n_layers_local_encoder - self.num_attention_heads_global = n_heads_global - self.num_attention_heads_local_decoder = n_heads_local_decoder - self.num_attention_heads_local_encoder = n_heads_local_encoder - self.num_key_value_heads_global = n_kv_heads_global + self.hidden_size_global = hidden_size_global + self.hidden_size_local_decoder = hidden_size_local_decoder + self.hidden_size_local_encoder = hidden_size_local_encoder + self.num_hidden_layers_global = num_hidden_layers_global + self.num_hidden_layers_local_decoder = num_hidden_layers_local_decoder + self.num_hidden_layers_local_encoder = num_hidden_layers_local_encoder + self.num_attention_heads_global = num_attention_heads_global + self.num_attention_heads_local_decoder = num_attention_heads_local_decoder + self.num_attention_heads_local_encoder = num_attention_heads_local_encoder + self.num_key_value_heads_global = num_key_value_heads_global # Transformer configuration self.norm_eps = norm_eps @@ -552,26 +526,20 @@ def __init__( self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions - # Model behavior - self.weight_tying = weight_tying - # Parameter mixing self.pm_size = pm_size - - # Special token IDs - self.boe_id = boe_id - + # Initialize component configurations self.encoder_config = BLTLocalEncoderConfig( vocab_size=vocab_size, cross_attn_all_layers=cross_attn_all_layers_encoder, cross_attn_k=cross_attn_k, - dim_global=dim_global, + hidden_size_global=hidden_size_global, pm_size=pm_size, - hidden_size=dim_local_encoder, - num_attention_heads=n_heads_local_encoder, + hidden_size=hidden_size_local_encoder, + num_attention_heads=num_attention_heads_local_encoder, num_key_value_heads=num_key_value_heads, - num_hidden_layers=n_layers_local_encoder, + num_hidden_layers=num_hidden_layers_local_encoder, norm_eps=norm_eps, dropout=dropout, max_position_embeddings=max_encoder_seq_length or max_position_embeddings, @@ -585,11 +553,11 @@ def __init__( vocab_size=vocab_size, cross_attn_all_layers=cross_attn_all_layers_decoder, cross_attn_k=cross_attn_k, - dim_global=dim_global, - hidden_size=dim_local_decoder, - num_attention_heads=n_heads_local_decoder, + hidden_size_global=hidden_size_global, + hidden_size=hidden_size_local_decoder, + num_attention_heads=num_attention_heads_local_decoder, num_key_value_heads=num_key_value_heads, - num_hidden_layers=n_layers_local_decoder, + num_hidden_layers=num_hidden_layers_local_decoder, norm_eps=norm_eps, dropout=dropout, max_position_embeddings=max_encoder_seq_length or max_position_embeddings, @@ -600,10 +568,10 @@ def __init__( ) self.global_config = BLTGlobalTransformerConfig( - hidden_size=dim_global, - num_attention_heads=n_heads_global, - num_key_value_heads=n_kv_heads_global, - num_hidden_layers=n_layers_global, + hidden_size=hidden_size_global, + num_attention_heads=num_attention_heads_global, + num_key_value_heads=num_key_value_heads_global, + num_hidden_layers=num_hidden_layers_global, norm_eps=norm_eps, dropout=dropout, max_position_embeddings=max_position_embeddings, @@ -611,13 +579,10 @@ def __init__( rope_scaling={"rope_type": "default"}, hidden_act=hidden_act, multiple_of=multiple_of, - global_dim_patch_emb=dim_local_encoder * cross_attn_k, + global_dim_patch_emb=hidden_size_local_encoder * cross_attn_k, ) - if patcher_args is not None: - self.patcher_config = BLTPatcherConfig(**patcher_args) - else: - self.patcher_config = BLTPatcherConfig() + self.patcher_config = BLTPatcherConfig(**patcher_args) # Handle hash byte group size validation if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: @@ -629,23 +594,15 @@ def __init__( self.rope_scaling = {"rope_type": "default"} # Set compatibility attributes for transformers - self.num_key_value_heads = n_heads_local_encoder + self.num_key_value_heads = num_attention_heads_local_encoder self.max_position_embeddings = max_position_embeddings - self.hidden_size = dim_local_encoder - self.num_attention_heads = n_heads_local_encoder + self.hidden_size = hidden_size_local_encoder + self.num_attention_heads = num_attention_heads_local_encoder # Calculate intermediate_size using BLTMLP logic for each component # Note: Each component uses its own hidden dimension, not the main dim self.intermediate_size = None # Will be calculated per component - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - **kwargs, - ) - - __all__ = [ "BLTConfig", diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index ff019f572dfb..10d307214697 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -589,7 +589,7 @@ def forward( layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - if idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder: + if idx == len(self.layers) - 1 or self.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) @@ -655,7 +655,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( - in_features=config.dim_global, + in_features=config.hidden_size_global, out_features=config.decoder_dim_token_emb * config.cross_attn_k, bias=False, ) @@ -729,7 +729,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.config = config self.layer_idx = layer_idx # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.dim_local_encoder + self.hidden_size = hidden_size or config.hidden_size_local_encoder self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads @@ -901,7 +901,7 @@ def _init_weights(self, module): elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.dim_local_encoder ** (-0.5) + emb_std = module.config.hidden_size_local_encoder ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std @@ -911,10 +911,10 @@ def _init_weights(self, module): elif isinstance(module, BLTLocalDecoder): if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5) elif isinstance(module, BLTPatcher): - emb_std = module.config.dim ** (-0.5) + emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std @@ -931,7 +931,7 @@ def __init__(self, config: BLTConfig): self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=config.dim_local_encoder, + local_encoder_dim=config.hidden_size_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) @@ -1028,16 +1028,16 @@ def __init__(self, config: BLTPatcherConfig): self.layers = nn.ModuleList() # Create transformer layers using the patcher config - for layer_idx in range(self.config.n_layers): + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps) + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) self.lm_head = nn.Linear( - self.config.dim, + self.config.hidden_size, self.config.vocab_size, bias=False, ) @@ -1055,7 +1055,7 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] predictions = [] - max_length = self.config.max_seqlen + max_length = self.config.max_position_embeddings batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) From 81e1f7856e024f34232bd14991f93be359ea8a39 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 14:44:09 +0000 Subject: [PATCH 045/139] clean up configs --- src/transformers/models/blt_wip/configuration_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 9e3f4e30341c..b6d61195366a 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -464,7 +464,7 @@ def __init__( # Parameter mixing pm_size=0, # Patcher configuration - patcher_args=None, + patcher_args={}, # Inherited **kwargs, ): From 60f57bbc69b11052bd216532aa7f23a4d7e9326f Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 26 Jun 2025 16:09:47 +0000 Subject: [PATCH 046/139] update modular --- ...modeling_blt_modular.py => modular_blt.py} | 1037 ++++++++--------- 1 file changed, 484 insertions(+), 553 deletions(-) rename src/transformers/models/blt_wip/{modeling_blt_modular.py => modular_blt.py} (52%) diff --git a/src/transformers/models/blt_wip/modeling_blt_modular.py b/src/transformers/models/blt_wip/modular_blt.py similarity index 52% rename from src/transformers/models/blt_wip/modeling_blt_modular.py rename to src/transformers/models/blt_wip/modular_blt.py index 0dd0b76d840c..9e78ba1129d9 100644 --- a/src/transformers/models/blt_wip/modeling_blt_modular.py +++ b/src/transformers/models/blt_wip/modular_blt.py @@ -21,6 +21,7 @@ from ...activations import ACT2FN import torch +import torch.distributions import torch.nn import torch.nn as nn from torch.nn import functional as F @@ -30,22 +31,58 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from .configuration_blt import ( BLTConfig, + BLTLocalEncoderConfig, + BLTLocalDecoderConfig, + BLTGlobalTransformerConfig, + BLTPatcherConfig, PatchingModeEnum, ) -from ..mllama.modeling_mllama import MllamaTextRMSNorm, MllamaTextMLP, MllamaTextCrossAttention, MllamaRotaryEmbedding, MllamaTextSelfAttention, MllamaSelfAttentionDecoderLayer, eager_attention_forward, repeat_kv - if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask +from ..mllama.modeling_mllama import repeat_kv, eager_attention_forward, MllamaRotaryEmbedding, MllamaTextRMSNorm, MllamaCrossAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextSelfAttention logger = logging.get_logger(__name__) -# Copied from transformers.models.mllama.modeling_mllama.MllamaTextMLP -class BLTMLP(MllamaTextMLP): - pass + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -80,70 +117,262 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText -class BLTRMSNorm(MllamaTextRMSNorm): - pass +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + -class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): - pass +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention with MllamaText->BLT -class BLTSelfAttention(MllamaTextSelfAttention): - pass + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -def check_non_zero_after_zero(tensor): - zero_mask = tensor == 0 - shifted_mask = torch.cat( - [ - torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), - zero_mask[:, :-1], - ], - dim=1, - ) - non_zero_after_zero = (tensor != 0) & shifted_mask - return non_zero_after_zero.any() +class BLTTransformerLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BLTSelfAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim ** -0.5 + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + self.config._attn_implementation = "sdpa" + self.scaling = None + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, + 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, + 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, ] prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - prime_powers = torch.stack([prime**i for i in range(token_tensor.shape[-1])]) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime ** powers return torch.sum(token_tensor * prime_powers, dim=-1) def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): - """ - Returns a hash of the input token_ids and maps it to a value in the range [0, max_hash]. - - expects: token_ids of shape (batch_size, seq_len) with values as ids in the token vocab. - returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. - - Note: max hash can make a big difference on the number of collisions. - """ + """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): batch_size, seq_len = token_ids.shape - prefix = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - token_ids = torch.cat([prefix, token_ids], dim=1) - windows = token_ids.unfold(1, group_size, 1) - # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values_range = hashes % max_hash - hash_values_range.requires_grad = False - return hash_values_range + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) + for _ in range(num_embeddings) + ] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings def _prepare_patch_cross_attention_mask( @@ -273,53 +502,64 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded -class BLTRotaryEmbedding(MllamaRotaryEmbedding): - pass +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTConfig): + def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - - # Extract config values to instance attributes + + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.dim_local_encoder = config.dim_local_encoder - self.n_layers_local_encoder = config.n_layers_local_encoder - self.n_heads_local_encoder = config.n_heads_local_encoder - self.vocab_size = config.vocab_size - self.pm_size = config.pm_size - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_nheads = config.cross_attn_nheads - self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder - self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k - self.norm_eps = config.norm_eps - self.sliding_window = config.sliding_window - encoder_config = config.encoder_config - self.layers = nn.ModuleList([BLTTransformerLayer(encoder_config, layer_idx) for layer_idx in range(self.n_layers_local_encoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) - self.rotary_emb = BLTRotaryEmbedding(config=encoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = ( - nn.Linear(config.encoder_dim_token_emb, self.dim_local_encoder, bias=False) - if config.encoder_dim_token_emb is not None and config.encoder_dim_token_emb != self.dim_local_encoder - else None + self.patch_embedding_projection = nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=config.encoder_dim_token_emb * config.cross_attn_k, + bias=False, ) - self.patch_embedding_projection = self._create_patch_projection(config) + self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) - self.embed_tokens = nn.Embedding(self.vocab_size + self.pm_size, self.dim_local_encoder) - - self.cross_attn_layers = None - if self.cross_attn_encoder and self.cross_attn_nheads is not None: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_encoder if self.cross_attn_all_layers_encoder else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_encoder) - ) + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) def forward( self, @@ -334,34 +574,28 @@ def forward( cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ - batch_size, sequence_length = input_ids.shape if input_embeds is None: input_embeds = self.embed_tokens(input_ids) batch_size, _, _ = input_embeds.shape - hidden_states = input_embeds - - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) for idx, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - if self.cross_attn_encoder and (idx == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): - # Initialize patch_embeds if not provided when cross attention is enabled - if patch_embeds is None: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) - if self.patch_embedding_projection is not None: - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_encoder) + if idx == len(self.layers) - 1 or self.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) - layer_idx = idx if self.cross_attn_all_layers_encoder else 0 + layer_idx = idx if self.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -373,27 +607,9 @@ def forward( ) patch_embeds = patch_embeds + cross_attention_output - encoder_cross_states = patch_embeds if self.cross_attn_encoder else None - return (hidden_states, encoder_cross_states), cache + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states - def _create_patch_projection(self, config): - dimension_mismatch = config.encoder_dim_patch_emb is not None and config.encoder_dim_patch_emb != config.dim_local_encoder - - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - if not (dimension_mismatch or cross_attn_conditions): - return None - - output_dim = config.encoder_dim_token_emb * config.cross_attn_k - - return nn.Linear( - in_features=config.encoder_dim_patch_emb, - out_features=output_dim, - bias=False, - ) - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch @@ -405,7 +621,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, seq_len, embedding_dim = hidden_states.shape + batch_size, _, embedding_dim = hidden_states.shape patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -423,70 +639,43 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BLTLocalDecoder(nn.Module): - def __init__(self, config: BLTConfig): + def __init__(self, config: BLTLocalDecoderConfig): super().__init__() # Extract config values to instance attributes - self.dim_local_decoder = config.dim_local_decoder - self.n_heads_local_decoder = config.n_heads_local_decoder - self.n_layers_local_decoder = config.n_layers_local_decoder - self.vocab_size = config.vocab_size - self.norm_eps = config.norm_eps + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_nheads = config.cross_attn_nheads - self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_all_layers = config.cross_attn_all_layers self.cross_attn_k = config.cross_attn_k - self.sliding_window = config.sliding_window - decoder_config = config.decoder_config - self.layers = nn.ModuleList([BLTTransformerLayer(decoder_config, layer_idx) for layer_idx in range(self.n_layers_local_decoder)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) - self.rotary_emb = BLTRotaryEmbedding(config=decoder_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = ( - nn.Linear(config.decoder_dim_token_emb, self.dim_local_decoder, bias=False) - if config.decoder_dim_token_emb is not None and config.decoder_dim_token_emb != self.dim_local_decoder - else None + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.decoder_dim_token_emb * config.cross_attn_k, + bias=False, ) - self.patch_embedding_projection = self._create_patch_projection(config) + self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps) - self.norm = BLTRMSNorm(self.dim_local_decoder, eps=self.norm_eps) - - self.cross_attn_layers = None - if self.cross_attn_decoder and self.cross_attn_nheads is not None: - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.n_layers_local_decoder if self.cross_attn_all_layers_decoder else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.dim_local_decoder) - ) + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) self.lm_head = nn.Linear( - self.dim_local_decoder, + self.hidden_size, self.vocab_size, bias=False, ) - def _create_patch_projection(self, config): - dimension_mismatch = config.dim_global is not None and config.dim_global != config.dim_local_decoder - - # Check cross attention conditions - cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( - config.cross_attn_decoder and config.cross_attn_init_by_pooling - ) - - if not (dimension_mismatch or cross_attn_conditions): - return None - - output_dim = config.decoder_dim_token_emb * config.cross_attn_k - - return nn.Linear( - in_features=config.dim_global, - out_features=output_dim, - bias=False, - ) def forward( self, @@ -498,18 +687,12 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, sequence_length = tokens.shape - batch_size, seq_length, _ = embeds.shape - - assert embeds is not None, "Embeddings must be provided" + batch_size, _, _ = embeds.shape hidden_states = embeds - if self.patch_embedding_projection is not None: - assert patch_embeds is not None, "Patch embeddings must be passed." - patch_embeds = self.patch_embedding_projection(patch_embeds) - if self.cross_attn_k is not None: - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.dim_local_decoder) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds @@ -519,7 +702,7 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + if i == 0 or self.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, @@ -547,9 +730,9 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.config = config self.layer_idx = layer_idx # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.dim_local_encoder - self.num_heads = config.cross_attn_nheads - self.num_key_value_heads = config.cross_attn_nheads # Assuming same for cross attention + self.hidden_size = hidden_size or config.hidden_size_local_encoder + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = None #self.head_dim ** -0.5 @@ -604,6 +787,8 @@ def forward( attention_interface: Callable = eager_attention_forward + self.config._attn_implementation = "sdpa" + attn = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -644,47 +829,33 @@ def forward( class BLTGlobalTransformer(nn.Module): - def __init__(self, config): + def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() - self.dim_global = config.dim_global - self.n_heads_global = config.n_heads_global - self.n_layers_global = config.n_layers_global + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers self.dropout = config.dropout - global_config = config.global_config self.layers = nn.ModuleList() - for layer_idx in range(self.n_layers_global): - self.layers.append(BLTTransformerLayer(global_config, layer_idx)) + for layer_idx in range(self.num_hidden_layers): + self.layers.append(BLTTransformerLayer(config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=global_config) + self.rotary_emb = BLTRotaryEmbedding(config=config) - self.token_embedding_projection = None - if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim_global: - self.token_embedding_projection = nn.Linear( - config.global_dim_patch_emb, - self.dim_global, - bias=False, - ) def forward( self, - input_ids: torch.Tensor, - tok_idx: Optional[torch.Tensor] = None, - input_embeds: Optional[torch.Tensor] = None, + input_embeds: torch.Tensor, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): - batch_size, seq_length, _ = input_embeds.shape + batch_size, seq_len, _ = input_embeds.shape hidden_states = input_embeds - if self.token_embedding_projection is not None and hidden_states.shape[-1] != self.dim_global: - hidden_states = self.token_embedding_projection(hidden_states) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): @@ -694,48 +865,6 @@ def forward( return hidden_states, cache -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """ - Compute embeddings using hash token embeddings. - - Args: - local_encoder_tokens: Input tokens tensor - local_encoder: Encoder object with embed_tokens method - encoder_hash_tok_embedding: ModuleList of hash token embeddings - encoder_hash_byte_group_nb_functions: Number of hash functions - encoder_hash_byte_group_size: List of byte group sizes - encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings - - Returns: - torch.Tensor: Combined embeddings - """ - if encoder_hash_tok_embedding is None: - return None - - local_encoder_embeds = local_encoder.embed_tokens(local_encoder_tokens) - - i = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for byte_group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, - byte_group_size, - hash_func_nb=func_nb, - max_hash=encoder_hash_byte_group_vocab, - ) - hash_tok_embedding = encoder_hash_tok_embedding[i] - local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) - i += 1 - - assert i == len(encoder_hash_tok_embedding) - return local_encoder_embeds class BLTPreTrainedModel(PreTrainedModel): @@ -751,7 +880,6 @@ class BLTPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Linear): std = getattr(module, '_custom_std', module.in_features ** (-0.5)) - nn.init.trunc_normal_( module.weight, mean=0.0, @@ -764,7 +892,6 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( module.weight, mean=0.0, @@ -775,30 +902,20 @@ def _init_weights(self, module): elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.dim_local_encoder ** (-0.5) + emb_std = module.config.hidden_size_local_encoder ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std elif isinstance(module, BLTLocalEncoder): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_encoder ** (-0.5) - if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) elif isinstance(module, BLTLocalDecoder): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.config.dim_local_decoder ** (-0.5) - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.dim_global ** (-0.5) - - elif isinstance(module, BLTGlobalTransformer): - if module.token_embedding_projection is not None: - module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5) elif isinstance(module, BLTPatcher): - emb_std = module.config.dim ** (-0.5) + emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std @@ -807,116 +924,62 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - # Extract frequently used config values - self.patch_in_forward = config.patch_in_forward - self.patching_mode = config.patching_mode - self.patch_size = config.patch_size - self.patching_threshold = config.patching_threshold - self.max_patch_length = config.max_patch_length - self.patching_batch_size = config.patching_batch_size - self.patching_device = config.patching_device - self.cross_attn_encoder = config.cross_attn_encoder - self.cross_attn_decoder = config.cross_attn_decoder - self.cross_attn_k = config.cross_attn_k - self.cross_attn_window_encoder = config.cross_attn_window_encoder - self.cross_attn_window_decoder = config.cross_attn_window_decoder - self.boe_id = config.boe_id - self.eos_token_id = config.eos_token_id + self.config = config - self.local_encoder = BLTLocalEncoder(config) - self.global_transformer = BLTGlobalTransformer(config) - self.local_decoder = BLTLocalDecoder(config) + self.local_encoder = BLTLocalEncoder(config.encoder_config) + self.global_transformer = BLTGlobalTransformer(config.global_config) + self.local_decoder = BLTLocalDecoder(config.decoder_config) self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=config.dim_local_encoder, + local_encoder_dim=config.hidden_size_local_encoder, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if self.patch_in_forward: - self.patcher = BLTPatcher(config) + if self.config.patch_in_forward: + self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() for param in self.patcher.parameters(): param.requires_grad = False else: self.patcher = None - def forward( - self, - tokens: torch.Tensor, - patch_lengths: Optional[torch.Tensor] = None, - ): - # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings - # are no longer used in the final BLT model - - batch_size, sequence_length = tokens.shape # Batch size and sequence length - - local_encoder_tokens, local_decoder_tokens = tokens, tokens + def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + batch_size, sequence_length = tokens.shape - # Patching + # Handle patching if patch_lengths is None: - # assert ( - # getattr(self, "patch_in_forward", None) is not None and self.patch_in_forward - # ), "Patch in forward not enabled and no patch_lengths passed." - - # PATCHER MODEL DEFINED - if self.patching_mode == PatchingModeEnum.entropy: + if self.config.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( - local_encoder_tokens, - patch_size=self.patch_size, - include_next_token=True, - threshold=self.patching_threshold, - max_patch_length=self.max_patch_length, - patching_batch_size=self.patching_batch_size, - device=self.patching_device, + tokens, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, ) else: - # self.patching_mode == PatchingModeEnum.byte - batch_size_tokens, seq_len = local_encoder_tokens.shape - seq_len_next_tok = seq_len + 1 # include_next_token=True - patch_lengths = torch.ones( - (batch_size_tokens, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + # Default to byte-level patching + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + self.config.max_patch_length ) - patch_lengths = process_patch_lengths(patch_lengths, self.max_patch_length) - - #assert torch.min(patch_lengths) >= 0 - # Generate patch IDs from patch_lengths - patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) - # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( - # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" - # ) - - # Cross-attention encoder - cross_attn_mask_enc = None - full_text_row_masked_out_mask_enc = None - if self.cross_attn_encoder: - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids=patch_ids, - num_patches=patch_lengths.shape[1], - sequence_length=sequence_length, - patches_as_queries=True, - cross_attn_k=self.cross_attn_k, - dtype=torch.float32, - ) - - # Hashing and embedding - local_encoder_embeds = compute_hash_embeddings( - local_encoder_tokens=local_encoder_tokens, - local_encoder=self.local_encoder, - encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, - encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, - encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, - encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 ) - # NOTE: Frequency-based n-gram embeddings removed as per paper - # The final BLT model uses only hash-based n-gram embeddings + encoder_embeds = compute_hash_embeddings( + tokens, self.local_encoder, self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) - # Local encoder - (encoder_hidden_states, encoder_cross_states), cache_encoder = self.local_encoder( - input_ids=local_encoder_tokens, - input_embeds=local_encoder_embeds, + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=tokens, + input_embeds=encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, @@ -924,126 +987,58 @@ def forward( patch_ids=patch_ids, ) - # Downsampling - if encoder_cross_states is not None: - # Cross attention is enabled - use cross states - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - else: - # Cross attention is disabled - use reduced embeddings from encoder hidden states - global_hidden_states = self.local_encoder.patch_reduce( - encoder_hidden_states, patch_lengths.shape[1], "amax", patch_ids - ) - - # Global transformer - global_tokens = tokens.new(global_hidden_states.shape[0], global_hidden_states.shape[1]).fill_(self.boe_id) - rows, cols = torch.where(local_encoder_tokens == self.eos_token_id) - eos_patch_ids = patch_ids[rows, cols] - global_tokens[rows, eos_patch_ids] = self.eos_token_id + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_hidden_states, _ = self.global_transformer( input_embeds=global_hidden_states, - input_ids=global_tokens, ) - # Unpatching - - decoder_embeds = encoder_hidden_states - - # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) - # assert torch.max(decoder_patch_ids) + 1 <= global_hidden_states.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {global_hidden_states.shape[1]}" - # assert decoder_patch_ids.shape[1] == decoder_embeds.shape[1], ( - # f"{decoder_patch_ids.shape[1]} != {decoder_embeds.shape[1]}" - # ) - - # Cross-attention decoder - cross_attn_mask_dec = None - full_text_row_masked_out_mask_dec = None - if not self.cross_attn_decoder: - patch_hidden_states = torch.gather(global_hidden_states, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, global_hidden_states.shape[-1])) - # assert local_decoder_tokens.shape == patch_hidden_states.shape[:-1] - else: - patch_hidden_states = global_hidden_states - cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - patch_ids=decoder_patch_ids, - num_patches=patch_lengths.shape[1], - sequence_length=sequence_length, - patches_as_queries=False, - cross_attn_k=self.cross_attn_k, - dtype=torch.float32, - ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + ) - # Local decoder output, _ = self.local_decoder( - embeds=decoder_embeds, - patch_embeds=patch_hidden_states, - tokens=local_decoder_tokens, + tokens=tokens, + embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) + return output def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """ - Convert patch lengths to patch IDs for each token position. - For each token position in the sequence, determines which patch it belongs to. - - Args: - patch_lengths: [batch_size, num_patches] - length of each patch - seq_len: total sequence length - - Returns: - patch_ids: [batch_size, seq_len] - patch index for each token position - - Example: - patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 - seq_len = 10 - Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] - # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 - """ - batch_size, num_patches = patch_lengths.shape - - # Create patch start positions: [0, 3, 5, 9] for the example above - patch_starts = torch.cat( - [ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total - ], - dim=-1, - ) - - # For each token position, find which patch it belongs to - # by finding the rightmost patch start that's <= the position - token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] - - # Broadcasting: patch_starts[batch, patch] <= token_positions[position] - # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t - position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) - - # Count how many patch starts are <= each position, then subtract 1 to get patch index - patch_ids = position_ge_patch_start.sum(dim=-1) - 1 - - return patch_ids + """Convert patch lengths to patch IDs for each token position.""" + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] + ], dim=-1) + + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config): - super().__init__(config.patcher_config) + def __init__(self, config: BLTPatcherConfig): + super().__init__(config) self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() # Create transformer layers using the patcher config - for layer_idx in range(self.config.n_layers): + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.dim) + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.dim, eps=self.config.norm_eps) + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) self.lm_head = nn.Linear( - self.config.dim, + self.config.hidden_size, self.config.vocab_size, bias=False, ) @@ -1052,7 +1047,6 @@ def forward( self, token_values: torch.Tensor, patch_size: Optional[int] = None, - include_next_token: bool = True, threshold: Optional[float] = None, max_patch_length: Optional[int] = None, patching_batch_size: int = 1, @@ -1062,7 +1056,7 @@ def forward( # Handle chunked processing for entropy calculation entropies = [] predictions = [] - max_length = self.config.max_seqlen + max_length = self.config.max_position_embeddings batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) @@ -1093,7 +1087,7 @@ def forward( logits = self.lm_head(self.norm(hidden_states)) logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] predictions.append(logits) - prediction_entropies = self.entropy(logits) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) @@ -1101,142 +1095,79 @@ def forward( # Always compute patch lengths from concatenated entropies batch_size, sequence_length = token_values.shape - seq_len_next_tok = sequence_length + 1 if include_next_token else sequence_length # Find patch start IDs based on entropy if patch_size is not None: - patch_start_ids = self.find_entropy_patch_start_ids( - concat_entropies, - patch_size, - include_next_token=include_next_token, - threshold=threshold + patch_lengths = self.patch_lengths_from_entropies( + entropies=concat_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, ) - patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) else: # Default to byte-level patching - patch_lengths = torch.ones((batch_size, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) - + patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) return concat_entropies, patch_lengths, concat_predictions - - @staticmethod - def entropy(scores): - """ - scores: [batch_size, seq_len, vocab] - returns [batch_size, seq_len] - - Computes the entropy for each token in the batch. - Note: uses natural log. - """ - log_probs = F.log_softmax(scores, dim=-1) - probs = torch.exp(log_probs) - p_log_p = log_probs * probs - entropy = -p_log_p.sum(dim=-1) - return entropy - - @staticmethod - def patch_start_ids_from_patch_start_mask(patch_start_mask): - batch_size, trunc_seq_len = patch_start_mask.shape - max_patches = patch_start_mask.sum(dim=1).max() - if max_patches == 0: - patch_start_ids = torch.full( - (batch_size, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - else: - patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(batch_size, 1) - extra_patch_ids = torch.full( - (batch_size, trunc_seq_len), - trunc_seq_len, - dtype=torch.long, - device=patch_start_mask.device, - ) - all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) - patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) - patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(batch_size, trunc_seq_len)[:, :max_patches] - return patch_start_ids - - @staticmethod - def patch_lengths_from_start_ids(patch_start_ids, seq_len): - """ - Calculate patch lengths from start ids. - start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then - the rest are filled to the seq len. - seq_len: ex: 7 length of the sequence - - returns the patch lengths: - [1, 6] for the above example. - """ - last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) - patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) - patch_lengths = patch_end_ids - patch_start_ids + 1 - assert torch.all(patch_lengths >= 0), f"{patch_lengths}" - assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" - return patch_lengths - @staticmethod - def find_entropy_patch_start_ids( + def patch_lengths_from_entropies( entropies, + sequence_length, patch_size=None, threshold=None, - include_next_token=True, ): """ - Use entropies to find the start ids of each patch. - Use patch_size or threshold to figure out the total number of patches to allocate. + Computes patch lengths from token entropies. - When threshold is not None the number of patches is not constant between - different sequences, but patches can be identified incrementally rather than - decided globally using the entire sequence. + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). """ - batch_size, sequence_length = entropies.shape[:2] - first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - predictions_truncation_len = first_ids.shape[1] # remove the first predictions because they will be start of patches. + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) entropies = entropies[:, 1:] + if threshold is None: + # Use top-k entropy values to define patch start points num_patches = sequence_length // patch_size - patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices - patch_start_ids = patch_start_ids.sort(dim=1).values + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values else: - patch_start_mask = entropies > threshold - if not include_next_token: - patch_start_mask = patch_start_mask[:, :-1] - # patch_start_mask[1:] |= tokens[:-1] < OFFSET - patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) - - patch_start_ids = torch.cat((first_ids, patch_start_ids + predictions_truncation_len), dim=1) - return patch_start_ids - -def init_hash_embeddings( - config, - local_encoder_dim: int, - encoder_hash_byte_group_size: list, -): - """Initialize hash-based token embeddings for the BLT encoder.""" - if config.encoder_hash_byte_group_size is None: - return None - - embeddings = [] - emb_dim = local_encoder_dim - encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab - - for _ in range(config.encoder_hash_byte_group_nb_functions): - for _ in encoder_hash_byte_group_size: - embeddings.append( - nn.Embedding( - encoder_hash_byte_group_vocab, - emb_dim, - ) - ) + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold - return nn.ModuleList(embeddings) + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths __all__ = [ "BLTPreTrainedModel", From c2a90995970a3034607bb11deac37bb50eb0a2b5 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 30 Jun 2025 10:22:14 +0000 Subject: [PATCH 047/139] clean --- .../models/blt_wip/configuration_blt.py | 18 +------------ .../models/blt_wip/modeling_blt.py | 25 +++++++++++-------- .../models/blt_wip/modular_blt.py | 12 ++++----- 3 files changed, 22 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index b6d61195366a..ca88ae89b738 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -28,15 +28,6 @@ class InitStdFactor(str, Enum): CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - class BLTLocalEncoderConfig(PretrainedConfig): """ Configuration class for the BLT Local Encoder component. @@ -241,7 +232,6 @@ def __init__( ffn_dim_multiplier=None, multiple_of=256, rope_theta=10000.0, - rope_use_fp32_in_outer_product=False, attn_impl="sdpa", attn_bias_type="causal", bos_token_id=1, @@ -269,13 +259,10 @@ def __init__( **kwargs, ) - # Add attributes needed for compatibility with transformer models self.hidden_act = "silu" # BLT uses silu activation - - # Calculate intermediate_size using BLTMLP logic based on actual hidden_size + self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of) - # Set simple rope scaling for patcher (no complex dynamic rope) self.rope_scaling = {"rope_type": "default"} @@ -442,7 +429,6 @@ def __init__( # Positional encoding rope_theta=10000.0, # Attention configuration - attn_impl="sdpa", _attn_implementation="sdpa", # Patching configuration patch_in_forward=False, @@ -503,7 +489,6 @@ def __init__( self.rope_theta = rope_theta # Attention configuration - self.attn_impl = attn_impl self._attn_implementation = _attn_implementation # Patching configuration @@ -611,6 +596,5 @@ def __init__( "BLTLocalDecoderConfig", "BLTGlobalTransformerConfig", "InitStdFactor", - "PatchingModeEnum" ] diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 10d307214697..d0e857bddc9b 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -17,6 +17,8 @@ from ...utils import is_torch_flex_attn_available, logging from typing import Callable, List, Optional, Tuple, Union +from enum import Enum + from ...cache_utils import Cache from ...activations import ACT2FN @@ -35,7 +37,6 @@ BLTLocalDecoderConfig, BLTGlobalTransformerConfig, BLTPatcherConfig, - PatchingModeEnum, ) if is_torch_flex_attn_available(): @@ -46,6 +47,15 @@ logger = logging.get_logger(__name__) +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -149,9 +159,6 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - class BLTTransformerLayer(nn.Module): def __init__(self, config, layer_idx: int): @@ -242,7 +249,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = None self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -284,7 +291,6 @@ def forward( attention_interface: Callable = eager_attention_forward output_attentions = False self.config._attn_implementation = "sdpa" - self.scaling = None if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -787,7 +793,6 @@ def forward( attention_interface: Callable = eager_attention_forward self.config._attn_implementation = "sdpa" - attn = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -1077,14 +1082,14 @@ def forward( position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) diff --git a/src/transformers/models/blt_wip/modular_blt.py b/src/transformers/models/blt_wip/modular_blt.py index 9e78ba1129d9..f433e2c8b799 100644 --- a/src/transformers/models/blt_wip/modular_blt.py +++ b/src/transformers/models/blt_wip/modular_blt.py @@ -243,7 +243,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = None self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -284,8 +284,8 @@ def forward( attention_interface: Callable = eager_attention_forward output_attentions = False - self.config._attn_implementation = "sdpa" - self.scaling = None + # self.config._attn_implementation = "sdpa" + # self.scaling = None if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -735,7 +735,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None #self.head_dim ** -0.5 + self.scaling = None self.dropout = config.dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -787,8 +787,8 @@ def forward( attention_interface: Callable = eager_attention_forward - self.config._attn_implementation = "sdpa" - attn = "sdpa" + # self.config._attn_implementation = "sdpa" + # attn = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( From 6c0b8d2fe1600c418f04c8f0e224efb2a2e8228d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 30 Jun 2025 10:27:47 +0000 Subject: [PATCH 048/139] update demo --- src/demo_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index f43bac62fb6f..b048823ea2f6 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -101,7 +101,7 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - blt_repo = "itazap/blt-1b-converted" + blt_repo = "itazap/blt-1b-hf" model = BLTModel.from_pretrained(blt_repo).to(device) tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) From 9c7139813b5f6dd6961d81f203f70f39e06ddd2e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 30 Jun 2025 15:12:33 +0000 Subject: [PATCH 049/139] config more like mllama, seperated subconfigs from subdicts --- src/convert_blt_to_hf.py | 146 ++++++-- src/demo_hf.py | 2 +- .../models/blt_wip/configuration_blt.py | 348 ++++-------------- .../models/blt_wip/modeling_blt.py | 28 +- 4 files changed, 188 insertions(+), 336 deletions(-) diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index 3b8eb79baba3..e0cd418183a8 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -39,57 +39,125 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str if isinstance(patch_size, float): patch_size = int(patch_size) + # Create patcher config + patcher_hidden_size = int(entropy_model_params.get("dim", 512)) + patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256)) + patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of) + patcher_config = { "vocab_size": int(entropy_model_params.get("vocab_size", 256)), - "dim": int(entropy_model_params.get("dim", 512)), - "n_layers": int(entropy_model_params.get("n_layers", 8)), - "n_heads": int(entropy_model_params.get("n_heads", 8)), - "head_dim": int(entropy_model_params.get("head_dim")) - if entropy_model_params.get("head_dim") is not None - else None, - "n_kv_heads": int(entropy_model_params.get("n_kv_heads")) + "hidden_size": patcher_hidden_size, + "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)), + "num_attention_heads": int(entropy_model_params.get("n_heads", 8)), + "num_key_value_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, - "max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)), "norm_eps": entropy_model_params.get("norm_eps", 1e-5), "dropout": entropy_model_params.get("dropout", 0.0), - "sliding_window": int(entropy_model_params.get("sliding_window", 512)) - if entropy_model_params.get("sliding_window") is not None - else None, - "ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), - "multiple_of": int(entropy_model_params.get("multiple_of", 256)), "rope_theta": entropy_model_params.get("rope_theta", 10000.0), - "rope_use_fp32_in_outer_product": entropy_model_params.get( - "rope_use_fp32_in_outer_product", False - ), "attn_impl": entropy_model_params.get("attn_impl", "sdpa"), "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), - "init_base_std": entropy_model_params.get("init_base_std"), - "init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), - "dim_token_emb": entropy_model_params.get("dim_token_emb"), - "weight_tying": entropy_model_params.get("weight_tying", False), - "bos_token_id": entropy_model_params.get("bos_token_id", 1), - "eos_token_id": entropy_model_params.get("eos_token_id", 2), + "intermediate_size": patcher_intermediate_size, } - unified_config.update( - { - "patch_in_forward": True, - "realtime_patching": True, - "patching_mode": "entropy", - "patch_size": patch_size, - "patching_threshold": patcher_args.get("threshold", 0.5), - "patching_threshold_add": patcher_args.get("threshold_add", 0.0), - "max_patch_length": patcher_args.get("max_patch_length"), - "patching_batch_size": patcher_args.get("patching_batch_size", 1), - "patching_device": patcher_args.get("patching_device", "cuda"), - "monotonicity": patcher_args.get("monotonicity", False), - "patcher_args": patcher_config, - } - ) + # Create encoder config + encoder_hidden_size = unified_config.get("dim_local_encoder", 1024) + encoder_multiple_of = unified_config.get("multiple_of", 256) + encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of) + + encoder_config = { + "vocab_size": unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "pm_size": unified_config.get("pm_size", 0), + "hidden_size": encoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_encoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": encoder_intermediate_size, + } + + # Create decoder config + decoder_hidden_size = unified_config.get("dim_local_decoder", 1024) + decoder_multiple_of = unified_config.get("multiple_of", 256) + decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of) + + decoder_config = { + "vocab_size": unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "hidden_size": decoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_decoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": decoder_intermediate_size, + } + + # Create global transformer config + global_hidden_size = unified_config.get("dim_global", 2048) + global_multiple_of = unified_config.get("multiple_of", 256) + global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of) + + global_config = { + "hidden_size": global_hidden_size, + "num_attention_heads": unified_config.get("n_heads_global", 16), + "num_key_value_heads": unified_config.get("n_kv_heads_global"), + "num_hidden_layers": unified_config.get("n_layers_global", 25), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": global_intermediate_size, + } + + # Create main config with sub-configs + main_config_dict = { + "model_type": "blt", + "vocab_size": unified_config.get("vocab_size", 256), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"), + "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000), + "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3), + "pm_size": unified_config.get("pm_size", 0), + "patcher_config": patcher_config, + "encoder_config": encoder_config, + "decoder_config": decoder_config, + "global_config": global_config, + } - logger.info(f"Merged configuration with {len(unified_config)} parameters") - return unified_config + logger.info(f"Merged configuration with {len(main_config_dict)} parameters") + return main_config_dict def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: diff --git a/src/demo_hf.py b/src/demo_hf.py index b048823ea2f6..b5856cc7f39e 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -101,7 +101,7 @@ def generate( def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - blt_repo = "itazap/blt-1b-hf" + blt_repo = "itazap/blt-1b-converted" model = BLTModel.from_pretrained(blt_repo).to(device) tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index ca88ae89b738..6d5990cdb57a 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -15,6 +15,7 @@ """BLT model configuration""" from enum import Enum +from typing import Union from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -22,12 +23,6 @@ logger = logging.get_logger(__name__) - -class InitStdFactor(str, Enum): - DISABLED = "disabled" # Init std is divided by 1.0 - CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) - - class BLTLocalEncoderConfig(PretrainedConfig): """ Configuration class for the BLT Local Encoder component. @@ -41,12 +36,9 @@ def __init__( cross_attn_all_layers=True, cross_attn_k=2, hidden_size_global=2048, - pm_size=0, hidden_size=512, num_attention_heads=8, num_key_value_heads=None, - head_dim=None, - intermediate_size=None, num_hidden_layers=8, norm_eps=1e-5, dropout=0.0, @@ -54,7 +46,7 @@ def __init__( rope_theta=10000.0, rope_scaling=None, hidden_act="silu", - multiple_of=256, + intermediate_size=None, _attn_implementation="sdpa", **kwargs, ): @@ -62,12 +54,11 @@ def __init__( self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k self.hidden_size_global = hidden_size_global - self.pm_size = pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads - self.head_dim = head_dim or (hidden_size // num_attention_heads) - self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers self.norm_eps = norm_eps self.dropout = dropout @@ -75,11 +66,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - self.multiple_of = multiple_of self._attn_implementation = _attn_implementation - self.decoder_dim_token_emb = 1024 - self.encoder_dim_token_emb = 1024 - self.encoder_dim_patch_emb = self.hidden_size super().__init__(**kwargs) @@ -99,8 +86,6 @@ def __init__( hidden_size=512, num_attention_heads=8, num_key_value_heads=None, - head_dim=None, - intermediate_size=None, num_hidden_layers=8, norm_eps=1e-5, dropout=0.0, @@ -108,7 +93,7 @@ def __init__( rope_theta=10000.0, rope_scaling=None, hidden_act="silu", - multiple_of=256, + intermediate_size=None, _attn_implementation="sdpa", **kwargs, ): @@ -119,8 +104,8 @@ def __init__( self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads - self.head_dim = head_dim or (hidden_size // num_attention_heads) - self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers self.norm_eps = norm_eps self.dropout = dropout @@ -128,10 +113,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - self.multiple_of = multiple_of self._attn_implementation = _attn_implementation - self.decoder_dim_token_emb = 1024 - self.encoder_dim_token_emb = 1024 super().__init__(**kwargs) @@ -148,8 +130,6 @@ def __init__( hidden_size=512, num_attention_heads=8, num_key_value_heads=None, - head_dim=None, - intermediate_size=None, num_hidden_layers=8, norm_eps=1e-5, dropout=0.0, @@ -157,16 +137,15 @@ def __init__( rope_theta=10000.0, rope_scaling=None, hidden_act="silu", - multiple_of=256, + intermediate_size=None, _attn_implementation="sdpa", - global_dim_patch_emb=None, **kwargs, ): self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads - self.head_dim = head_dim or (hidden_size // num_attention_heads) - self.intermediate_size = intermediate_size or multiple_of * ((int(8 * hidden_size / 3) + multiple_of - 1) // multiple_of) + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers self.norm_eps = norm_eps self.dropout = dropout @@ -174,9 +153,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - self.multiple_of = multiple_of self._attn_implementation = _attn_implementation - self.global_dim_patch_emb = global_dim_patch_emb super().__init__(**kwargs) @@ -224,46 +201,32 @@ def __init__( hidden_size=512, num_hidden_layers=8, num_attention_heads=8, - head_dim=None, num_key_value_heads=None, max_position_embeddings=1024, norm_eps=1e-5, dropout=0.0, - ffn_dim_multiplier=None, - multiple_of=256, rope_theta=10000.0, attn_impl="sdpa", attn_bias_type="causal", - bos_token_id=1, - eos_token_id=2, + intermediate_size=None, **kwargs, ): self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads) + self.head_dim = hidden_size // num_attention_heads self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads self.max_position_embeddings = max_position_embeddings self.norm_eps = norm_eps self.dropout = dropout - self.intermediate_size = ffn_dim_multiplier - self.multiple_of = multiple_of self.rope_theta = rope_theta self.attn_impl = attn_impl self.attn_bias_type = attn_bias_type - - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) - self.hidden_act = "silu" # BLT uses silu activation - - self.intermediate_size = multiple_of * ((int(8 * self.hidden_size / 3) + multiple_of - 1) // multiple_of) - + self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) self.rope_scaling = {"rope_type": "default"} + super().__init__(**kwargs) class BLTConfig(PretrainedConfig): @@ -280,58 +243,6 @@ class BLTConfig(PretrainedConfig): max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. - # Main architecture dimensions - hidden_size (`int`, *optional*, defaults to 512): - Main dimension of the model. - num_hidden_layers (`int`, *optional*, defaults to 8): - Number of layers in the main transformer. - num_attention_heads (`int`, *optional*, defaults to 8): - Number of attention heads in the main transformer. - head_dim (`int`, *optional*): - Dimension of each attention head. If not specified, computed as hidden_size // num_attention_heads. - num_key_value_heads (`int`, *optional*): - Number of key-value heads for grouped query attention. If not specified, defaults to num_attention_heads. - - # Component-specific dimensions - hidden_size_global (`int`, *optional*, defaults to 512): - Dimension of the global transformer component. - hidden_size_local_decoder (`int`, *optional*, defaults to 512): - Dimension of the local decoder component. - hidden_size_local_encoder (`int`, *optional*, defaults to 512): - Dimension of the local encoder component. - num_hidden_layers_global (`int`, *optional*, defaults to 8): - Number of layers in the global transformer. - num_hidden_layers_local_decoder (`int`, *optional*, defaults to 8): - Number of layers in the local decoder. - num_hidden_layers_local_encoder (`int`, *optional*, defaults to 8): - Number of layers in the local encoder. - num_attention_heads_global (`int`, *optional*, defaults to 8): - Number of attention heads in the global transformer. - num_attention_heads_local_decoder (`int`, *optional*, defaults to 8): - Number of attention heads in the local decoder. - num_attention_heads_local_encoder (`int`, *optional*, defaults to 8): - Number of attention heads in the local encoder. - num_key_value_heads_global (`int`, *optional*): - Number of key-value heads in the global transformer. - - # Transformer configuration - norm_eps (`float`, *optional*, defaults to 1e-5): - The epsilon used by the layer normalization layers. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability for all fully connected layers. - ffn_dim_multiplier (`float`, *optional*, defaults to 1.0): - Multiplier for the feedforward network dimension. - multiple_of (`int`, *optional*, defaults to 256): - Make feedforward network dimension multiple of this value. - - # Positional encoding - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - - # Attention configuration - attn_impl (`str`, *optional*, defaults to "sdpa"): - Attention implementation to use ("sdpa" or "flex_attention"). - # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): Whether to perform patching during forward pass. @@ -351,14 +262,8 @@ class BLTConfig(PretrainedConfig): # Cross attention configurations cross_attn_k (`int`, *optional*): Number of cross attention components. - cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False): - Whether to apply cross attention to all decoder layers. - cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): - Whether to apply cross attention to all encoder layers. # Encoder configurations - max_encoder_seq_length (`int`, *optional*): - Maximum sequence length for encoder. encoder_hash_byte_group_size (`Any`, *optional*): Hash byte group size for encoder. encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): @@ -366,31 +271,24 @@ class BLTConfig(PretrainedConfig): encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): Number of hash functions for byte groups. - # Parameter mixing - pm_size (`int`, *optional*, defaults to 0): - Parameter mixing size. - - # Special tokens - bos_token_id (`int`, *optional*, defaults to 1): - The id of the "beginning-of-sequence" token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the "end-of-sequence" token. - pad_token_id (`int`, *optional*, defaults to -1): - The id of the padding token. - - # Patcher configuration - patcher_args (`dict`, *optional*): - Dictionary containing configuration arguments for the BLT patcher/entropy model. - If provided, these will be used to initialize a BLTPatcherConfig instance. + # Component configurations + patcher_config (`Union[BLTPatcherConfig, dict]`, *optional*): + Configuration for the BLT patcher/entropy model component. + encoder_config (`Union[BLTLocalEncoderConfig, dict]`, *optional*): + Configuration for the BLT local encoder component. + decoder_config (`Union[BLTLocalDecoderConfig, dict]`, *optional*): + Configuration for the BLT local decoder component. + global_config (`Union[BLTGlobalTransformerConfig, dict]`, *optional*): + Configuration for the BLT global transformer component. ```python - >>> from transformers import ByteLatentTransformer, BLTConfig + >>> from transformers import BLTModel, BLTConfig >>> # Initializing a BLT configuration >>> configuration = BLTConfig() >>> # Initializing a model from the configuration - >>> model = ByteLatentTransformer(configuration) + >>> model = BLTModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -398,60 +296,31 @@ class BLTConfig(PretrainedConfig): model_type = "blt" keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = { + "patcher_config": BLTPatcherConfig, + "encoder_config": BLTLocalEncoderConfig, + "decoder_config": BLTLocalDecoderConfig, + "global_config": BLTGlobalTransformerConfig + } def __init__( self, vocab_size=256, max_position_embeddings=1024, - # Main architecture dimensions - hidden_size=512, - num_hidden_layers=8, - num_attention_heads=8, - head_dim=None, - num_key_value_heads=None, - # Component-specific dimensions - hidden_size_global=512, - hidden_size_local_decoder=512, - hidden_size_local_encoder=512, - num_hidden_layers_global=8, - num_hidden_layers_local_decoder=8, - num_hidden_layers_local_encoder=8, - num_attention_heads_global=8, - num_attention_heads_local_decoder=8, - num_attention_heads_local_encoder=8, - num_key_value_heads_global=None, - # Transformer configuration - norm_eps=1e-5, - dropout=0.0, - ffn_dim_multiplier=1.0, - multiple_of=256, - hidden_act="silu", - # Positional encoding - rope_theta=10000.0, - # Attention configuration - _attn_implementation="sdpa", - # Patching configuration patch_in_forward=False, patch_size=None, patching_mode=None, patching_threshold=None, patching_batch_size=1, - patching_device="cuda", max_patch_length=None, - # Cross attention configurations cross_attn_k=2, - cross_attn_all_layers_decoder=False, - cross_attn_all_layers_encoder=False, - # Encoder configurations - max_encoder_seq_length=None, encoder_hash_byte_group_size=None, encoder_hash_byte_group_vocab=30000, encoder_hash_byte_group_nb_functions=3, - # Parameter mixing - pm_size=0, - # Patcher configuration - patcher_args={}, - # Inherited + patcher_config=None, + encoder_config=None, + decoder_config=None, + global_config=None, **kwargs, ): @@ -459,135 +328,56 @@ def __init__( self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - # Main architecture dimensions - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim if head_dim is not None else (hidden_size // num_attention_heads) - self.num_key_value_heads = num_key_value_heads - - # Component-specific dimensions - self.hidden_size_global = hidden_size_global - self.hidden_size_local_decoder = hidden_size_local_decoder - self.hidden_size_local_encoder = hidden_size_local_encoder - self.num_hidden_layers_global = num_hidden_layers_global - self.num_hidden_layers_local_decoder = num_hidden_layers_local_decoder - self.num_hidden_layers_local_encoder = num_hidden_layers_local_encoder - self.num_attention_heads_global = num_attention_heads_global - self.num_attention_heads_local_decoder = num_attention_heads_local_decoder - self.num_attention_heads_local_encoder = num_attention_heads_local_encoder - self.num_key_value_heads_global = num_key_value_heads_global - - # Transformer configuration - self.norm_eps = norm_eps - self.dropout = dropout - self.intermediate_size = ffn_dim_multiplier - self.multiple_of = multiple_of - self.hidden_act = hidden_act - - # Positional encoding - self.rope_theta = rope_theta - - # Attention configuration - self._attn_implementation = _attn_implementation - # Patching configuration self.patch_in_forward = patch_in_forward self.patch_size = patch_size self.patching_mode = patching_mode self.patching_threshold = patching_threshold self.patching_batch_size = patching_batch_size - self.patching_device = patching_device self.max_patch_length = max_patch_length # Cross attention configurations self.cross_attn_k = cross_attn_k - self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder - self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder # Encoder configurations - self.max_encoder_seq_length = max_encoder_seq_length - self.encoder_hash_byte_group_size = encoder_hash_byte_group_size + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [2, 3, 4] self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions - # Parameter mixing - self.pm_size = pm_size - # Initialize component configurations - self.encoder_config = BLTLocalEncoderConfig( - vocab_size=vocab_size, - cross_attn_all_layers=cross_attn_all_layers_encoder, - cross_attn_k=cross_attn_k, - hidden_size_global=hidden_size_global, - pm_size=pm_size, - hidden_size=hidden_size_local_encoder, - num_attention_heads=num_attention_heads_local_encoder, - num_key_value_heads=num_key_value_heads, - num_hidden_layers=num_hidden_layers_local_encoder, - norm_eps=norm_eps, - dropout=dropout, - max_position_embeddings=max_encoder_seq_length or max_position_embeddings, - rope_theta=rope_theta, - rope_scaling={"rope_type": "default"}, - hidden_act=hidden_act, - multiple_of=multiple_of, - ) - - self.decoder_config = BLTLocalDecoderConfig( - vocab_size=vocab_size, - cross_attn_all_layers=cross_attn_all_layers_decoder, - cross_attn_k=cross_attn_k, - hidden_size_global=hidden_size_global, - hidden_size=hidden_size_local_decoder, - num_attention_heads=num_attention_heads_local_decoder, - num_key_value_heads=num_key_value_heads, - num_hidden_layers=num_hidden_layers_local_decoder, - norm_eps=norm_eps, - dropout=dropout, - max_position_embeddings=max_encoder_seq_length or max_position_embeddings, - rope_theta=rope_theta, - rope_scaling={"rope_type": "default"}, - hidden_act=hidden_act, - multiple_of=multiple_of, - ) - - self.global_config = BLTGlobalTransformerConfig( - hidden_size=hidden_size_global, - num_attention_heads=num_attention_heads_global, - num_key_value_heads=num_key_value_heads_global, - num_hidden_layers=num_hidden_layers_global, - norm_eps=norm_eps, - dropout=dropout, - max_position_embeddings=max_position_embeddings, - rope_theta=rope_theta, - rope_scaling={"rope_type": "default"}, - hidden_act=hidden_act, - multiple_of=multiple_of, - global_dim_patch_emb=hidden_size_local_encoder * cross_attn_k, - ) - - self.patcher_config = BLTPatcherConfig(**patcher_args) - - # Handle hash byte group size validation - if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: - self.encoder_hash_byte_group_size = [ - int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 - ] - - # Rope scaling configuration - self.rope_scaling = {"rope_type": "default"} - - # Set compatibility attributes for transformers - self.num_key_value_heads = num_attention_heads_local_encoder - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size_local_encoder - self.num_attention_heads = num_attention_heads_local_encoder - - # Calculate intermediate_size using BLTMLP logic for each component - # Note: Each component uses its own hidden dimension, not the main dim - self.intermediate_size = None # Will be calculated per component + if patcher_config is None: + self.patcher_config = BLTPatcherConfig() + logger.info("patcher_config is None, using default BLT patcher config") + elif isinstance(patcher_config, dict): + self.patcher_config = BLTPatcherConfig(**patcher_config) + elif isinstance(patcher_config, BLTPatcherConfig): + self.patcher_config = patcher_config + + if encoder_config is None: + self.encoder_config = BLTLocalEncoderConfig() + logger.info("encoder_config is None, using default BLT encoder config") + elif isinstance(encoder_config, dict): + self.encoder_config = BLTLocalEncoderConfig(**encoder_config) + elif isinstance(encoder_config, BLTLocalEncoderConfig): + self.encoder_config = encoder_config + + if decoder_config is None: + self.decoder_config = BLTLocalDecoderConfig() + logger.info("decoder_config is None, using default BLT decoder config") + elif isinstance(decoder_config, dict): + self.decoder_config = BLTLocalDecoderConfig(**decoder_config) + elif isinstance(decoder_config, BLTLocalDecoderConfig): + self.decoder_config = decoder_config + + if global_config is None: + self.global_config = BLTGlobalTransformerConfig() + logger.info("global_config is None, using default BLT global config") + elif isinstance(global_config, dict): + self.global_config = BLTGlobalTransformerConfig(**global_config) + elif isinstance(global_config, BLTGlobalTransformerConfig): + self.global_config = global_config + super().__init__(**kwargs) __all__ = [ "BLTConfig", @@ -595,6 +385,4 @@ def __init__( "BLTLocalEncoderConfig", "BLTLocalDecoderConfig", "BLTGlobalTransformerConfig", - "InitStdFactor", ] - diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index d0e857bddc9b..32b11ae8e316 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -46,7 +46,6 @@ logger = logging.get_logger(__name__) - class PatchingModeEnum(str, Enum): entropy = "entropy" bpe = "bpe" @@ -130,7 +129,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BLTMLP(nn.Module): def __init__(self, config): super().__init__() - self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -552,12 +550,12 @@ def __init__(self, config: BLTLocalEncoderConfig): self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( - in_features=config.encoder_dim_patch_emb, - out_features=config.encoder_dim_token_emb * config.cross_attn_k, + in_features=self.hidden_size, + out_features=self.hidden_size * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) + self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 @@ -662,7 +660,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, - out_features=config.decoder_dim_token_emb * config.cross_attn_k, + out_features=self.hidden_size * config.cross_attn_k, bias=False, ) @@ -735,7 +733,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.config = config self.layer_idx = layer_idx # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.hidden_size_local_encoder + self.hidden_size = hidden_size or config.encoder_config.hidden_size self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads @@ -812,7 +810,7 @@ def forward( key_states, value_states, attention_mask, - dropout=0.0, #if not self.training else self.dropout, + dropout=0.0, scaling=self.scaling, **kwargs, ) @@ -820,7 +818,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - # Apply full row masking if provided (following mllama pattern) if full_text_row_masked_out_mask is not None: attn_output = full_text_row_masked_out_mask[:, 0] * attn_output @@ -906,17 +903,17 @@ def _init_weights(self, module): elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.hidden_size_local_encoder ** (-0.5) + emb_std = module.config.encoder_config.hidden_size ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std elif isinstance(module, BLTLocalEncoder): if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) elif isinstance(module, BLTLocalDecoder): if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5) + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) elif isinstance(module, BLTPatcher): emb_std = module.config.hidden_size ** (-0.5) @@ -936,7 +933,7 @@ def __init__(self, config: BLTConfig): self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=config.hidden_size_local_encoder, + local_encoder_dim=config.encoder_config.hidden_size, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) @@ -960,7 +957,7 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, patching_batch_size=self.config.patching_batch_size, - device=self.config.patching_device, + device=tokens.device, ) else: # Default to byte-level patching @@ -1032,12 +1029,11 @@ def __init__(self, config: BLTPatcherConfig): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - # Create transformer layers using the patcher config for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) From 31700033ebb0343d5a46ad892693d106abe2ce21 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 1 Jul 2025 09:42:11 +0000 Subject: [PATCH 050/139] read from config instead of self args --- .../models/blt_wip/modeling_blt.py | 69 ++++++++----------- 1 file changed, 29 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 32b11ae8e316..01b48351e4c9 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -85,8 +85,8 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -538,30 +538,25 @@ class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - self.hidden_size = config.hidden_size - self.vocab_size=config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout - self.cross_attn_all_layers = config.cross_attn_all_layers - self.cross_attn_k = config.cross_attn_k + self.config = config - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( - in_features=self.hidden_size, - out_features=self.hidden_size * config.cross_attn_k, + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) def forward( @@ -582,23 +577,23 @@ def forward( batch_size, _, _ = input_embeds.shape - hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training) + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for idx, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - if idx == len(self.layers) - 1 or self.cross_attn_all_layers: + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) - layer_idx = idx if self.cross_attn_all_layers else 0 + layer_idx = idx if self.config.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -646,36 +641,31 @@ def __init__(self, config: BLTLocalDecoderConfig): super().__init__() # Extract config values to instance attributes - self.hidden_size = config.hidden_size - self.vocab_size=config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout + self.config = config self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove - self.cross_attn_all_layers = config.cross_attn_all_layers - self.cross_attn_k = config.cross_attn_k - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, - out_features=self.hidden_size * config.cross_attn_k, + out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps) + self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) self.lm_head = nn.Linear( - self.hidden_size, - self.vocab_size, + config.hidden_size, + config.vocab_size, bias=False, ) @@ -695,7 +685,7 @@ def forward( hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds @@ -703,9 +693,9 @@ def forward( position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): - if i == 0 or self.cross_attn_all_layers: + if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, @@ -833,12 +823,10 @@ class BLTGlobalTransformer(nn.Module): def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout + self.config = config self.layers = nn.ModuleList() - for layer_idx in range(self.num_hidden_layers): + for layer_idx in range(config.num_hidden_layers): self.layers.append(BLTTransformerLayer(config, layer_idx)) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -854,7 +842,7 @@ def forward( hidden_states = input_embeds - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -1029,6 +1017,7 @@ def __init__(self, config: BLTPatcherConfig): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) From 17720dc58f904f4d03939a165c0d47a9e86ac7db Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 1 Jul 2025 12:14:33 +0000 Subject: [PATCH 051/139] update demo file --- src/demo_hf.py | 43 ++----------------------------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index b5856cc7f39e..d6271f715325 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -21,26 +21,6 @@ def get_generation_range(prompt_tokens: list[list[int]] | None, max_gen_len: int return batch_min_prompt_length, batch_max_prompt_length + max_gen_len -def sample_top_k(probs, k): - topk_value, _ = torch.topk(probs, k) # batch_sz x topk - min_value_top_k = topk_value[:, [-1]] - probs[probs < min_value_top_k] = 0.0 - probs.div_(probs.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs, num_samples=1) - return next_token - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - - @torch.inference_mode() def generate( prompts: list[str] | None, @@ -49,11 +29,6 @@ def generate( tokenizer: BLTTokenizer, max_prompt_len: int = 256, max_gen_len: int = 256, - use_sampling: bool = False, - temp: float = 1.0, - top_k: int = 0, - top_p: float = 0.0, - remove_prompts: bool = True, device: torch.device = torch.device("cpu"), ) -> list[list[int]]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -75,26 +50,12 @@ def generate( current_tokens = tokens[:, :curr_pos] logits = model(current_tokens)[:, -1] - if use_sampling: - probs = torch.softmax(logits / temp, dim=-1) - if top_p > 0.0: - next_token = sample_top_p(probs, top_p) - elif top_k > 0: - next_token = sample_top_k(probs, top_k) - else: - next_token = torch.multinomial(probs, num_samples=1) - else: - next_token = torch.argmax(logits, dim=-1) + next_token = torch.argmax(logits, dim=-1) next_token = torch.where(input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token) tokens[:, curr_pos] = next_token - if remove_prompts: - generated_tokens = [ - t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens) - ] - else: - generated_tokens = [t[: len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens)] + generated_tokens = [t[: len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens)] return generated_tokens From 8ca6c73e4325a54f72ef86f1799e5d2b77d73a34 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 2 Jul 2025 14:53:54 +0000 Subject: [PATCH 052/139] model weights to causal lm weights --- src/demo_hf.py | 81 +- .../models/blt_wip/configuration_blt.py | 10 + .../models/blt_wip/modeling_blt_dev.py | 1225 +++++++++++++++++ .../models/blt_wip/tokenization_blt.py | 14 + 4 files changed, 1277 insertions(+), 53 deletions(-) create mode 100644 src/transformers/models/blt_wip/modeling_blt_dev.py diff --git a/src/demo_hf.py b/src/demo_hf.py index d6271f715325..e0a3ef369f06 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -2,8 +2,12 @@ import os import torch +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file +from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM from transformers.models.blt_wip.modeling_blt import BLTModel + from transformers.models.blt_wip.tokenization_blt import BLTTokenizer @@ -15,69 +19,40 @@ gc.collect() torch.cuda.empty_cache() -def get_generation_range(prompt_tokens: list[list[int]] | None, max_gen_len: int) -> tuple[int, int]: - batch_min_prompt_length = min([len(t) for t in prompt_tokens]) - batch_max_prompt_length = max([len(t) for t in prompt_tokens]) - return batch_min_prompt_length, batch_max_prompt_length + max_gen_len - - -@torch.inference_mode() -def generate( - prompts: list[str] | None, - *, - model: BLTModel, - tokenizer: BLTTokenizer, - max_prompt_len: int = 256, - max_gen_len: int = 256, - device: torch.device = torch.device("cpu"), -) -> list[list[int]]: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - model.eval() - prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts] - # Truncation - prompt_tokens = [t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :] for t in prompt_tokens] - start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len) - batch_size = len(prompt_tokens) - tokens = torch.full((batch_size, end_pos), tokenizer.pad_id, dtype=torch.long, device=device) - - # Copy inputs to tensor for generated tokens - for i, row_tokens in enumerate(prompt_tokens): - tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long() - input_text_mask = tokens != tokenizer.pad_id - - for i, curr_pos in enumerate(range(start_pos, end_pos)): - current_tokens = tokens[:, :curr_pos] - logits = model(current_tokens)[:, -1] - - next_token = torch.argmax(logits, dim=-1) - - next_token = torch.where(input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token) - tokens[:, curr_pos] = next_token - - generated_tokens = [t[: len(prompt_tokens[i]) + max_gen_len].tolist() for i, t in enumerate(tokens)] - return generated_tokens - def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - blt_repo = "itazap/blt-1b-converted" + blt_model = BLTModel.from_pretrained("itazap/blt-1b-converted") - model = BLTModel.from_pretrained(blt_repo).to(device) + causal_lm = BLTForCausalLM(blt_model.config) + causal_lm.model.load_state_dict(blt_model.state_dict(), strict=False) + causal_lm.lm_head.weight = blt_model.local_decoder.lm_head.weight + causal_lm.save_pretrained( "./blt-1b-causallm") + + # model = causal_lm + # model = model.to(device) + model = BLTForCausalLM.from_pretrained("./blt-1b-causallm").to(device) + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) - prompts = [prompt] + input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device) - outputs = generate(prompts, model=model, tokenizer=tokenizer, max_gen_len=200, device=device) + with torch.no_grad(): + output_ids = model.generate( + input_ids, + max_new_tokens=200, + do_sample=False, + temperature=1.0, + pad_token_id=tokenizer.pad_id, + eos_token_id=tokenizer.eos_id, + ) - text_outputs = [tokenizer.decode(t) for t in outputs] + generated_ids = output_ids[0][len(input_ids[0]):] + output_text = tokenizer.decode(generated_ids.tolist()) - for p, t in zip(prompts, text_outputs): - print(f'Prompt: "{p}"') - print(f'Completion: "{t}"') - print() - + print(f'Prompt: "{prompt}"') + print(f'Completion: "{output_text}"') print('here') diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 6d5990cdb57a..821d4f8202b4 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -321,6 +321,10 @@ def __init__( encoder_config=None, decoder_config=None, global_config=None, + # Generation configuration + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, **kwargs, ): @@ -328,6 +332,12 @@ def __init__( self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings + # Generation configuration + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.return_dict = True + # Patching configuration self.patch_in_forward = patch_in_forward self.patch_size = patch_size diff --git a/src/transformers/models/blt_wip/modeling_blt_dev.py b/src/transformers/models/blt_wip/modeling_blt_dev.py new file mode 100644 index 000000000000..c08fc10d9731 --- /dev/null +++ b/src/transformers/models/blt_wip/modeling_blt_dev.py @@ -0,0 +1,1225 @@ +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model.""" + +from ...utils import is_torch_flex_attn_available, logging +from typing import Callable, List, Optional, Tuple, Union + +from enum import Enum + +from ...cache_utils import Cache +from ...activations import ACT2FN + +import torch +import torch.distributions +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from .configuration_blt import ( + BLTConfig, + BLTLocalEncoderConfig, + BLTLocalDecoderConfig, + BLTGlobalTransformerConfig, + BLTPatcherConfig, +) + +from ...generation.utils import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class BLTTransformerLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BLTSelfAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + self.config._attn_implementation = "sdpa" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, + 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime ** powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) + for _ in range(num_embeddings) + ] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( + batch_size, num_patches, seq_len + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( + batch_size, seq_len, num_patches + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTLocalEncoderConfig): + super().__init__() + + self.config = config + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + def forward( + self, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + + batch_size, _, _ = input_embeds.shape + + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + + layer_idx = idx if self.config.cross_attn_all_layers else 0 + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + patch_embeds = patch_embeds + cross_attention_output + + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states + + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + batch_size, _, embedding_dim = hidden_states.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BLTLocalDecoder(nn.Module): + def __init__(self, config: BLTLocalDecoderConfig): + super().__init__() + + # Extract config values to instance attributes + self.config = config + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + + self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + # self.lm_head = nn.Linear( + # config.hidden_size, + # config.vocab_size, + # bias=False, + # ) + + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, _, _ = embeds.shape + + hidden_states = embeds + + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if i == 0 or self.config.cross_attn_all_layers: + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output, _, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + hidden_states = hidden_states + cross_attention_output + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.norm(hidden_states) + # logits = self.lm_head(logits) + return logits, cache + + +class BLTCrossAttention(nn.Module): + """Cross-attention module for BLT, following transformers style""" + + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.encoder_config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None #self.head_dim ** -0.5 + self.dropout = config.dropout + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" + ) + + attention_interface: Callable = eager_attention_forward + + self.config._attn_implementation = "sdpa" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if full_text_row_masked_out_mask is not None: + attn_output = full_text_row_masked_out_mask[:, 0] * attn_output + + attn_output = attn_output + hidden_states + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BLTGlobalTransformer(nn.Module): + def __init__(self, config: BLTGlobalTransformerConfig): + super().__init__() + + self.config = config + + self.layers = nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + self.layers.append(BLTTransformerLayer(config, layer_idx)) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + + def forward( + self, + input_embeds: torch.Tensor, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, seq_len, _ = input_embeds.shape + + hidden_states = input_embeds + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + return hidden_states, cache + + + + +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.config.encoder_config.hidden_size ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, BLTLocalEncoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.hidden_size ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + + self.config = config + + self.local_encoder = BLTLocalEncoder(config.encoder_config) + self.global_transformer = BLTGlobalTransformer(config.global_config) + self.local_decoder = BLTLocalDecoder(config.decoder_config) + + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=config.encoder_config.hidden_size, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + if self.config.patch_in_forward: + self.patcher = BLTPatcher(config.patcher_config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + batch_size, sequence_length = tokens.shape + + # Handle patching + if patch_lengths is None: + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + tokens, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=tokens.device, + ) + else: + # Default to byte-level patching + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + self.config.max_patch_length + ) + + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 + ) + + encoder_embeds = compute_hash_embeddings( + tokens, self.local_encoder, self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=tokens, + input_embeds=encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, + ) + + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + ) + + output, _ = self.local_decoder( + tokens=tokens, + embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + mask=None, + cross_mask=cross_attn_mask_dec, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + ) + + return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """Convert patch lengths to patch IDs for each token position.""" + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] + ], dim=-1) + + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config: BLTPatcherConfig): + super().__init__(config) + + self.rotary_emb = BLTRotaryEmbedding(config=self.config) + + self.layers = nn.ModuleList() + + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + + + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) + + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + ): + + # Handle chunked processing for entropy calculation + entropies = [] + predictions = [] + max_length = self.config.max_position_embeddings + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + batch_size, sequence_length = split.shape + input_embeds = self.embed_tokens(split) + + hidden_states = input_embeds + + batch_size, _, _ = input_embeds.shape + + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] + predictions.append(logits) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + entropies.append(prediction_entropies) + + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + batch_size, sequence_length = token_values.shape + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=concat_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + # Default to byte-level patching + patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return concat_entropies, patch_lengths, concat_predictions + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + if threshold is None: + # Use top-k entropy values to define patch start points + num_patches = sequence_length // patch_size + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values + else: + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + +class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.model = BLTModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + logits = self.model(input_ids) + + logits = self.lm_head(logits) + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + if past_key_values is not None: + input_ids = input_ids[:, -1:] + return {"input_ids": input_ids, "past_key_values": past_key_values} + + def get_input_embeddings(self): + return self.model.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.local_encoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", + "BLTTransformerLayer", + "BLTForCausalLM", +] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenization_blt.py b/src/transformers/models/blt_wip/tokenization_blt.py index f5fba8a50625..dfbe1602583c 100644 --- a/src/transformers/models/blt_wip/tokenization_blt.py +++ b/src/transformers/models/blt_wip/tokenization_blt.py @@ -15,6 +15,7 @@ """Tokenization classes for BLT.""" import os +import torch from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from ...tokenization_utils import AddedToken, PreTrainedTokenizer @@ -269,4 +270,17 @@ def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" return self.vocab_size_unit_1 + self.offsetting_special_char + def __call__(self, text, **kwargs): + """Override the default __call__ method to properly handle BOS/EOS tokens.""" + # Use our custom encode method to ensure consistent behavior + if isinstance(text, str): + tokens = self.encode(text, add_bos=self.add_bos_token, add_eos=self.add_eos_token) + return {"input_ids": torch.tensor([tokens])} + elif isinstance(text, list): + tokens_list = [self.encode(t, add_bos=self.add_bos_token, add_eos=self.add_eos_token) for t in text] + return {"input_ids": torch.tensor(tokens_list)} + else: + # Fallback to parent implementation + return super().__call__(text, **kwargs) + __all__ = ["BLTTokenizer"] \ No newline at end of file From a260bb1bd298eeff9c62bc23f3fa89387873b0ca Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 2 Jul 2025 14:58:46 +0000 Subject: [PATCH 053/139] missed file --- src/demo_hf.py | 6 +- .../models/blt_wip/modeling_blt.py | 57 ++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index e0a3ef369f06..3f87f87da842 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -30,10 +30,12 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): causal_lm.lm_head.weight = blt_model.local_decoder.lm_head.weight causal_lm.save_pretrained( "./blt-1b-causallm") - # model = causal_lm - # model = model.to(device) + # TRUE causal_lm.lm_head.weight == blt_model.local_decoder.lm_head.weight + model = BLTForCausalLM.from_pretrained("./blt-1b-causallm").to(device) + # FALSE model.lm_head.weight != blt_model.local_decoder.lm_head.weight + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device) diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 01b48351e4c9..9c98f0add3e1 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -39,6 +39,9 @@ BLTPatcherConfig, ) +from ...generation.utils import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast + if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask @@ -669,6 +672,9 @@ def __init__(self, config: BLTLocalDecoderConfig): bias=False, ) + z = 5 + z = 5+1 + def forward( self, @@ -711,7 +717,8 @@ def forward( layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - logits = self.lm_head(self.norm(hidden_states)) + logits = self.norm(hidden_states) + logits = self.lm_head(logits) return logits, cache @@ -1158,6 +1165,53 @@ def patch_lengths_from_entropies( return patch_lengths + +class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.model = BLTModel(config) + self.vocab_size = config.vocab_size + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + logits = self.model(input_ids) + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + if past_key_values is not None: + input_ids = input_ids[:, -1:] + return {"input_ids": input_ids, "past_key_values": past_key_values} + + def get_input_embeddings(self): + return self.model.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.local_encoder.embed_tokens = value + + def get_output_embeddings(self): + return self.model.local_decoder.lm_head + + def set_output_embeddings(self, new_embeddings): + self.model.local_decoder.lm_head = new_embeddings + __all__ = [ "BLTPreTrainedModel", "BLTModel", @@ -1166,4 +1220,5 @@ def patch_lengths_from_entropies( "BLTLocalDecoder", "BLTGlobalTransformer", "BLTTransformerLayer", + "BLTForCausalLM", ] \ No newline at end of file From 0b9db70fb5b473196387fcc4c63275e79c5788ab Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 2 Jul 2025 15:14:25 +0000 Subject: [PATCH 054/139] added tied weights keys --- .../models/blt_wip/modeling_blt_dev.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/transformers/models/blt_wip/modeling_blt_dev.py b/src/transformers/models/blt_wip/modeling_blt_dev.py index c08fc10d9731..3e6a12bd9b7c 100644 --- a/src/transformers/models/blt_wip/modeling_blt_dev.py +++ b/src/transformers/models/blt_wip/modeling_blt_dev.py @@ -911,6 +911,10 @@ def _init_weights(self, module): emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std + + elif isinstance(module, BLTForCausalLM): + if module.lm_head is not None: + module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) class BLTModel(BLTPreTrainedModel): @@ -1164,6 +1168,12 @@ def patch_lengths_from_entropies( class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config): super().__init__(config) self.model = BLTModel(config) @@ -1213,6 +1223,12 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + __all__ = [ "BLTPreTrainedModel", "BLTModel", From 107e26d549125b90daab95db7fa5e86e37d70c21 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 3 Jul 2025 13:25:37 +0000 Subject: [PATCH 055/139] BLTForCausalLM --- src/convert_blt_to_hf.py | 25 +- src/demo_hf.py | 13 +- .../models/blt_wip/configuration_blt.py | 14 +- .../models/blt_wip/modeling_blt.py | 149 +- .../models/blt_wip/modeling_blt_dev.py | 1241 ----------------- .../models/blt_wip/tokenization_blt.py | 15 - 6 files changed, 130 insertions(+), 1327 deletions(-) delete mode 100644 src/transformers/models/blt_wip/modeling_blt_dev.py diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py index e0cd418183a8..26c05477a169 100644 --- a/src/convert_blt_to_hf.py +++ b/src/convert_blt_to_hf.py @@ -10,6 +10,7 @@ from transformers.models.blt_wip.configuration_blt import BLTConfig from transformers.models.blt_wip.modeling_blt import BLTModel +from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM from transformers.utils import logging as transformers_logging @@ -156,6 +157,8 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str "global_config": global_config, } + main_config_dict["tie_word_embeddings"] = False + logger.info(f"Merged configuration with {len(main_config_dict)} parameters") return main_config_dict @@ -203,8 +206,6 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor elif "state_dict" in entropy_weights: entropy_weights = entropy_weights["state_dict"] - logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") - unified_weights = main_weights.copy() for key, tensor in entropy_weights.items(): @@ -213,6 +214,22 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor unified_weights = apply_weight_mapping(unified_weights) + decoder_lm_head_key = "local_decoder.lm_head.weight" + top_lm_head_key = "lm_head.weight" + unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key] + del unified_weights[decoder_lm_head_key] + + prefixed_weights = {} + for key, tensor in unified_weights.items(): + if key == top_lm_head_key: + prefixed_weights[key] = tensor + elif not key.startswith("model."): + prefixed_weights[f"model.{key}"] = tensor + else: + prefixed_weights[key] = tensor + + unified_weights = prefixed_weights + return unified_weights @@ -233,8 +250,6 @@ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): with open(tokenizer_path, "w") as f: json.dump(tokenizer_config, f, indent=2) - logger.info(f"Tokenizer config saved to {tokenizer_path}") - def push_to_hub( local_dir: str, @@ -344,7 +359,7 @@ def main(): parser.add_argument( "--push_to_hub", type=str, - default="itazap/blt-1b-converted", + default=None, ) parser.add_argument( "--hub_private", diff --git a/src/demo_hf.py b/src/demo_hf.py index 3f87f87da842..7c64e47b6723 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -23,19 +23,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - blt_model = BLTModel.from_pretrained("itazap/blt-1b-converted") - - causal_lm = BLTForCausalLM(blt_model.config) - causal_lm.model.load_state_dict(blt_model.state_dict(), strict=False) - causal_lm.lm_head.weight = blt_model.local_decoder.lm_head.weight - causal_lm.save_pretrained( "./blt-1b-causallm") - - # TRUE causal_lm.lm_head.weight == blt_model.local_decoder.lm_head.weight - - model = BLTForCausalLM.from_pretrained("./blt-1b-causallm").to(device) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device) - # FALSE model.lm_head.weight != blt_model.local_decoder.lm_head.weight - tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device) diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt_wip/configuration_blt.py index 821d4f8202b4..a9a245afebc3 100644 --- a/src/transformers/models/blt_wip/configuration_blt.py +++ b/src/transformers/models/blt_wip/configuration_blt.py @@ -321,23 +321,15 @@ def __init__( encoder_config=None, decoder_config=None, global_config=None, - # Generation configuration - bos_token_id=1, - eos_token_id=2, - pad_token_id=-1, + tie_word_embeddings=False, **kwargs, ): # Basic model configuration + self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - # Generation configuration - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.return_dict = True - # Patching configuration self.patch_in_forward = patch_in_forward self.patch_size = patch_size @@ -387,7 +379,7 @@ def __init__( elif isinstance(global_config, BLTGlobalTransformerConfig): self.global_config = global_config - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) __all__ = [ "BLTConfig", diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt_wip/modeling_blt.py index 9c98f0add3e1..84d874daa76e 100644 --- a/src/transformers/models/blt_wip/modeling_blt.py +++ b/src/transformers/models/blt_wip/modeling_blt.py @@ -666,14 +666,11 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) - self.lm_head = nn.Linear( - config.hidden_size, - config.vocab_size, - bias=False, - ) - - z = 5 - z = 5+1 + # self.lm_head = nn.Linear( + # config.hidden_size, + # config.vocab_size, + # bias=False, + # ) def forward( @@ -718,7 +715,7 @@ def forward( hidden_states = layer_outputs[0] logits = self.norm(hidden_states) - logits = self.lm_head(logits) + # logits = self.lm_head(logits) return logits, cache @@ -914,24 +911,24 @@ def _init_weights(self, module): emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std + + elif isinstance(module, BLTForCausalLM): + if module.lm_head is not None: + module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - self.config = config - self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - self.encoder_hash_tok_embedding = init_hash_embeddings( config, local_encoder_dim=config.encoder_config.hidden_size, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if self.config.patch_in_forward: self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() @@ -940,9 +937,30 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs, + ): + """ + Args: + tokens (torch.Tensor): Input token ids. + patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. + Returns: + torch.Tensor: Final hidden states (as before). + """ batch_size, sequence_length = tokens.shape - # Handle patching if patch_lengths is None: if self.config.patching_mode == PatchingModeEnum.entropy: @@ -955,24 +973,20 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = device=tokens.device, ) else: - # Default to byte-level patching patch_lengths = process_patch_lengths( torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), self.config.max_patch_length ) - patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 ) - encoder_embeds = compute_hash_embeddings( tokens, self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) - encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=tokens, input_embeds=encoder_embeds, @@ -982,18 +996,14 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = num_patches=patch_lengths.shape[1], patch_ids=patch_ids, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, _ = self.global_transformer( input_embeds=global_hidden_states, ) - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 ) - output, _ = self.local_decoder( tokens=tokens, embeds=encoder_hidden_states, @@ -1002,7 +1012,11 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) - + if output_hidden_states or output_attentions: + if return_dict: + return {"last_hidden_state": output, "hidden_states": None, "attentions": None} + else: + return (output, None, None) return output def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: @@ -1167,10 +1181,35 @@ def patch_lengths_from_entropies( class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + def __init__(self, config): super().__init__(config) self.model = BLTModel(config) self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.local_encoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model def forward( self, @@ -1183,35 +1222,59 @@ def forward( use_cache=None, output_attentions=None, output_hidden_states=None, + return_dict=None, cache_position=None, **kwargs, ): - logits = self.model(input_ids) - + """ + Args: + input_ids (torch.LongTensor): Input token ids. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. + labels (torch.LongTensor, optional): Labels for language modeling loss. + Returns: + CausalLMOutputWithPast or tuple: Standard transformers output. + """ + # Route only input_ids to BLTModel (as tokens) + hidden_states = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + if isinstance(hidden_states, dict): + sequence_output = hidden_states["last_hidden_state"] + elif isinstance(hidden_states, tuple): + sequence_output = hidden_states[0] + else: + sequence_output = hidden_states + logits = self.lm_head(sequence_output) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + if loss is not None: + output = (loss,) + output + return output return CausalLMOutputWithPast( + loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None, ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - input_ids = input_ids[:, -1:] - return {"input_ids": input_ids, "past_key_values": past_key_values} - - def get_input_embeddings(self): - return self.model.local_encoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.local_encoder.embed_tokens = value - - def get_output_embeddings(self): - return self.model.local_decoder.lm_head - - def set_output_embeddings(self, new_embeddings): - self.model.local_decoder.lm_head = new_embeddings - __all__ = [ "BLTPreTrainedModel", "BLTModel", diff --git a/src/transformers/models/blt_wip/modeling_blt_dev.py b/src/transformers/models/blt_wip/modeling_blt_dev.py deleted file mode 100644 index 3e6a12bd9b7c..000000000000 --- a/src/transformers/models/blt_wip/modeling_blt_dev.py +++ /dev/null @@ -1,1241 +0,0 @@ -# coding=utf-8 -# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""BLT model.""" - -from ...utils import is_torch_flex_attn_available, logging -from typing import Callable, List, Optional, Tuple, Union - -from enum import Enum - -from ...cache_utils import Cache -from ...activations import ACT2FN - -import torch -import torch.distributions -import torch.nn -import torch.nn as nn -from torch.nn import functional as F - -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update - -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from .configuration_blt import ( - BLTConfig, - BLTLocalEncoderConfig, - BLTLocalDecoderConfig, - BLTGlobalTransformerConfig, - BLTPatcherConfig, -) - -from ...generation.utils import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask - - -logger = logging.get_logger(__name__) - -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] - sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] - - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - - return q_rot.type_as(q), k_rot.type_as(k) - - -class BLTMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class BLTRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - BLTRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class BLTTransformerLayer(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor`, *optional*): - Position indices of tokens in the sequence for RoPE computation. - past_key_value (`Cache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class BLTSelfAttention(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.config = config - self.num_heads = config.num_attention_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None - self.rope_theta = config.rope_theta - self.layer_idx = layer_idx - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - output_attentions = False - self.config._attn_implementation = "sdpa" - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): - primes = [ - 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, - 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime ** powers - return torch.sum(token_tensor * prime_powers, dim=-1) - - -def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): - """Hash token groups and map to range [0, max_hash].""" - with torch.no_grad(): - batch_size, seq_len = token_ids.shape - # Add padding for sliding window - padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - padded_tokens = torch.cat([padding, token_ids], dim=1) - - # Create sliding windows and compute hashes - windows = padded_tokens.unfold(1, group_size, 1) - hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values = hashes % max_hash - - hash_values.requires_grad = False - return hash_values - - -def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): - """Initialize hash-based token embeddings for the BLT encoder.""" - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [ - nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) - for _ in range(num_embeddings) - ] - return nn.ModuleList(embeddings) - - -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """Compute token embeddings enhanced with hash-based embeddings.""" - embeddings = local_encoder.embed_tokens(local_encoder_tokens) - embedding_idx = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab - ) - embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) - embedding_idx += 1 - - return embeddings - - -def _prepare_patch_cross_attention_mask( - patch_ids: torch.Tensor, - num_patches: int, - sequence_length: int, - patches_as_queries: bool = False, - cross_attn_k: int = 1, - dtype: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Prepare cross-attention mask for patch-based attention, following mllama's robust approach. - - This function creates masks that control which patches can attend to which other patches, - with support for query/key role swapping and cross-attention multipliers. - - Args: - patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. - num_patches (int): Total number of patches. - sequence_length (int): Length of the sequence. - patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. - cross_attn_k (int): Cross-attention multiplier for repeating patches. - dtype (torch.dtype): Data type for the output mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows - """ - batch_size, seq_len = patch_ids.shape - device = patch_ids.device - - # Determine query and key lengths based on configuration - if patches_as_queries: - q_len = num_patches * cross_attn_k - kv_len = sequence_length - # Create patch-to-sequence mapping - q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( - batch_size, num_patches, seq_len - ) - kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) - else: - q_len = sequence_length - kv_len = num_patches * cross_attn_k - # Create sequence-to-patch mapping - q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) - kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( - batch_size, seq_len, num_patches - ) - - # Create base attention mask - boolean mask where True means "should attend" - # Exact patch matching - cross_attention_mask = q_patch_ids == kv_patch_ids - - # Handle cross_attn_k multiplier by repeating along appropriate dimension - repeat_dim = 1 if patches_as_queries else -1 - cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) - - # Validate dimensions - expected_shape = (batch_size, q_len, kv_len) - if cross_attention_mask.shape != expected_shape: - raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") - - # Reshape so it can be used by attn module - add head dimension - cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] - - # Invert the mask (following mllama pattern exactly) - # True -> 0.0 (attend), False -> 1.0 (will become -inf) - inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # Apply full-row bias (following mllama pattern exactly) - # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: - """ - Splits patch lengths into smaller segments if they exceed `max_patch_length`. - Pads the result to uniform length across the batch. - - Args: - patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. - max_patch_length (int, optional): Maximum allowed length per patch. - - Returns: - torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. - """ - if max_patch_length is None: - return patch_lengths - - batch_size = patch_lengths.size(0) - processed = [] - - for seq in patch_lengths: - splits = [] - for length in seq[seq > 0]: - length = length.item() - full_chunks, remainder = divmod(length, max_patch_length) - splits.extend([max_patch_length] * full_chunks) - if remainder: - splits.append(remainder) - processed.append(splits) - - # Find max length to pad to - max_len = max(len(splits) for splits in processed) - padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - - for i, splits in enumerate(processed): - if splits: - padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - - # Trim zero columns - if (padded != 0).any(dim=0).sum() < padded.shape[1]: - last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 - padded = padded[:, :last_nonzero] - - return padded - - -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config, device=None): - super().__init__() - self.rope_type = config.rope_scaling["rope_type"] - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTLocalEncoderConfig): - super().__init__() - - self.config = config - - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) - - self.rotary_emb = BLTRotaryEmbedding(config=config) - - self.patch_embedding_projection = nn.Linear( - in_features=config.hidden_size, - out_features=config.hidden_size * config.cross_attn_k, - bias=False, - ) - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) - ) - - def forward( - self, - input_ids: torch.Tensor, - input_embeds: Optional[torch.Tensor] = None, - patch_embeds: Optional[torch.Tensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - num_patches: Optional[int] = None, - patch_ids: Optional[torch.Tensor] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - """ """ - if input_embeds is None: - input_embeds = self.embed_tokens(input_ids) - - batch_size, _, _ = input_embeds.shape - - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - - for idx, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - hidden_states = layer_outputs[0] - - if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) - - layer_idx = idx if self.config.cross_attn_all_layers else 0 - cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( - hidden_states=patch_embeds, - cross_attention_states=hidden_states, - attention_mask=cross_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, - ) - patch_embeds = patch_embeds + cross_attention_output - - encoder_cross_states = patch_embeds - return hidden_states, encoder_cross_states - - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): - """ - Reduce variable length patches to single embedding per patch - Note: this works with variable number of patches for different sequences in the batch - It handles variable length patches by assuming that patch_lengths will be 0 for any - extra patches on the *right*. Since there can be a variable number of patches - this function also return the number of patches for each sequence in the batch. - Any embeddings on the right that are not allocated to a patch - (i.e. if the sum(patch_lengths[i]) < seq_len for any i) - will be sent to a dummy patch, which is trimmed before returning. - """ - batch_size, _, embedding_dim = hidden_states.shape - - patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) - - reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) - reduced_embeddings = reduced_embeddings.scatter_reduce( - src=hidden_states, - dim=1, - index=patch_ids, - reduce=reduction, - include_self=False, - ) - reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] - - return reduced_embeddings - - -class BLTLocalDecoder(nn.Module): - def __init__(self, config: BLTLocalDecoderConfig): - super().__init__() - - # Extract config values to instance attributes - self.config = config - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove - - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) - - self.rotary_emb = BLTRotaryEmbedding(config=config) - - self.patch_embedding_projection = nn.Linear( - in_features=config.hidden_size_global, - out_features=config.hidden_size * config.cross_attn_k, - bias=False, - ) - - self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - - self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 - for layer_idx in range(layers_to_add): - self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) - ) - - # self.lm_head = nn.Linear( - # config.hidden_size, - # config.vocab_size, - # bias=False, - # ) - - - def forward( - self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor], - patch_embeds: Optional[torch.Tensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - batch_size, _, _ = embeds.shape - - hidden_states = embeds - - patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) - - if patch_embeds is not None and not self.cross_attn_decoder: - hidden_states = hidden_states + patch_embeds - - position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - for i, layer in enumerate(self.layers): - if i == 0 or self.config.cross_attn_all_layers: - # Use cross attention to extract info from patch_embeds into hidden_states - cross_attention_output, _, _ = self.cross_attn_layers[i]( - hidden_states=hidden_states, - cross_attention_states=patch_embeds, - attention_mask=cross_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, - ) - hidden_states = hidden_states + cross_attention_output - - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - hidden_states = layer_outputs[0] - - logits = self.norm(hidden_states) - # logits = self.lm_head(logits) - return logits, cache - - -class BLTCrossAttention(nn.Module): - """Cross-attention module for BLT, following transformers style""" - - def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.encoder_config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None #self.head_dim ** -0.5 - self.dropout = config.dropout - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first - query_states = self.q_proj(query_states) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - if past_key_value is not None: - # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states - # we still update the cross key states, past_cross_states, new_cross_states. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position is not None and cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - if cross_attention_states is None: - raise ValueError( - "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" - ) - - attention_interface: Callable = eager_attention_forward - - self.config._attn_implementation = "sdpa" - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if full_text_row_masked_out_mask is not None: - attn_output = full_text_row_masked_out_mask[:, 0] * attn_output - - attn_output = attn_output + hidden_states - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class BLTGlobalTransformer(nn.Module): - def __init__(self, config: BLTGlobalTransformerConfig): - super().__init__() - - self.config = config - - self.layers = nn.ModuleList() - for layer_idx in range(config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(config, layer_idx)) - - self.rotary_emb = BLTRotaryEmbedding(config=config) - - - def forward( - self, - input_embeds: torch.Tensor, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, - ): - batch_size, seq_len, _ = input_embeds.shape - - hidden_states = input_embeds - - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - - position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - hidden_states = layer_outputs[0] - - return hidden_states, cache - - - - -class BLTPreTrainedModel(PreTrainedModel): - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation - _supports_sdpa = True - _supports_cache_class = False - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - std = getattr(module, '_custom_std', module.in_features ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - if module.bias is not None: - nn.init.zeros_(module.bias) - - elif isinstance(module, nn.Embedding): - std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - elif isinstance(module, BLTModel): - if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.encoder_config.hidden_size ** (-0.5) - for emb in module.encoder_hash_tok_embedding: - emb._custom_std = emb_std - - elif isinstance(module, BLTLocalEncoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - - elif isinstance(module, BLTLocalDecoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - - elif isinstance(module, BLTPatcher): - emb_std = module.config.hidden_size ** (-0.5) - module.embed_tokens._custom_std = emb_std - module.lm_head._custom_std = emb_std - - elif isinstance(module, BLTForCausalLM): - if module.lm_head is not None: - module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) - - -class BLTModel(BLTPreTrainedModel): - def __init__(self, config: BLTConfig): - super().__init__(config) - - self.config = config - - self.local_encoder = BLTLocalEncoder(config.encoder_config) - self.global_transformer = BLTGlobalTransformer(config.global_config) - self.local_decoder = BLTLocalDecoder(config.decoder_config) - - self.encoder_hash_tok_embedding = init_hash_embeddings( - config, - local_encoder_dim=config.encoder_config.hidden_size, - encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, - ) - - if self.config.patch_in_forward: - self.patcher = BLTPatcher(config.patcher_config) - self.patcher.eval() - for param in self.patcher.parameters(): - param.requires_grad = False - else: - self.patcher = None - - def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): - batch_size, sequence_length = tokens.shape - - # Handle patching - if patch_lengths is None: - if self.config.patching_mode == PatchingModeEnum.entropy: - _, patch_lengths, _ = self.patcher( - tokens, - patch_size=self.config.patch_size, - threshold=self.config.patching_threshold, - max_patch_length=self.config.max_patch_length, - patching_batch_size=self.config.patching_batch_size, - device=tokens.device, - ) - else: - # Default to byte-level patching - patch_lengths = process_patch_lengths( - torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), - self.config.max_patch_length - ) - - patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 - ) - - encoder_embeds = compute_hash_embeddings( - tokens, self.local_encoder, self.encoder_hash_tok_embedding, - self.config.encoder_hash_byte_group_nb_functions, - self.config.encoder_hash_byte_group_size, - self.config.encoder_hash_byte_group_vocab, - ) - - encoder_hidden_states, encoder_cross_states = self.local_encoder( - input_ids=tokens, - input_embeds=encoder_embeds, - patch_embeds=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - ) - - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - - global_hidden_states, _ = self.global_transformer( - input_embeds=global_hidden_states, - ) - - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) - cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 - ) - - output, _ = self.local_decoder( - tokens=tokens, - embeds=encoder_hidden_states, - patch_embeds=global_hidden_states, - mask=None, - cross_mask=cross_attn_mask_dec, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, - ) - - return output - - def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """Convert patch lengths to patch IDs for each token position.""" - batch_size = patch_lengths.shape[0] - patch_starts = torch.cat([ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1] - ], dim=-1) - - token_positions = torch.arange(seq_len, device=patch_lengths.device) - return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 - - -class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config: BLTPatcherConfig): - super().__init__(config) - - self.rotary_emb = BLTRotaryEmbedding(config=self.config) - - self.layers = nn.ModuleList() - - for layer_idx in range(self.config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - - - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) - - self.lm_head = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=False, - ) - - def forward( - self, - token_values: torch.Tensor, - patch_size: Optional[int] = None, - threshold: Optional[float] = None, - max_patch_length: Optional[int] = None, - patching_batch_size: int = 1, - device: Optional[str] = None, - ): - - # Handle chunked processing for entropy calculation - entropies = [] - predictions = [] - max_length = self.config.max_position_embeddings - batch_numel = max_length * patching_batch_size - splits = torch.split(token_values.flatten(), batch_numel) - - for split in splits: - pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) - split = torch.cat((split, pad), dim=0) - split = split.reshape(-1, max_length) - if device is not None: - split = split.to(device) - - # Process chunk: embeddings -> layers -> output - batch_size, sequence_length = split.shape - input_embeds = self.embed_tokens(split) - - hidden_states = input_embeds - - batch_size, _, _ = input_embeds.shape - - position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) - hidden_states = layer_outputs[0] - - logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] - predictions.append(logits) - prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() - entropies.append(prediction_entropies) - - concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - - # Always compute patch lengths from concatenated entropies - batch_size, sequence_length = token_values.shape - - # Find patch start IDs based on entropy - if patch_size is not None: - patch_lengths = self.patch_lengths_from_entropies( - entropies=concat_entropies, - sequence_length=sequence_length, - patch_size=patch_size, - threshold=threshold, - ) - else: - # Default to byte-level patching - patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) - patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_predictions - - @staticmethod - def patch_lengths_from_entropies( - entropies, - sequence_length, - patch_size=None, - threshold=None, - ): - """ - Computes patch lengths from token entropies. - - Depending on whether a threshold is provided, the function uses either: - - Top-k selection based on entropy (when `threshold` is None), or - - Thresholding the entropy values (when `threshold` is set). - """ - - batch_size = entropies.shape[0] - - # Always include token 0 and 1 as starting tokens - init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - offset = init_tokens.shape[1] - - # Ignore first token entropy (BOS) - entropies = entropies[:, 1:] - - if threshold is None: - # Use top-k entropy values to define patch start points - num_patches = sequence_length // patch_size - topk_indices = entropies.topk(num_patches - 2, dim=1).indices - patch_starts = topk_indices.sort(dim=1).values - else: - # Threshold the entropy values to define patch start points - patch_mask = entropies > threshold - - seq_len = patch_mask.shape[1] - - # Create patch IDs (token indices), and add a sentinel to ensure alignment - token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) - sentinel = torch.full_like(token_indices, seq_len) - padded_indices = torch.cat([token_indices, sentinel], dim=1) - - # Pad mask with inverse to align sentinel correctly - padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) - - # Select indices where mask is True - patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) - max_valid_patches = patch_mask.sum(dim=1).max() - patch_starts = patch_starts[:, :max_valid_patches] - - # Offset patch starts to account for the two initial tokens - patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) - - # Compute patch end positions by shifting start positions - last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) - patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) - - patch_lengths = patch_ends - patch_start_ids + 1 - - return patch_lengths - - -class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = BLTModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) - self.post_init() - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - **kwargs, - ): - logits = self.model(input_ids) - - logits = self.lm_head(logits) - - return CausalLMOutputWithPast( - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - input_ids = input_ids[:, -1:] - return {"input_ids": input_ids, "past_key_values": past_key_values} - - def get_input_embeddings(self): - return self.model.local_encoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.local_encoder.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - -__all__ = [ - "BLTPreTrainedModel", - "BLTModel", - "BLTPatcher", - "BLTLocalEncoder", - "BLTLocalDecoder", - "BLTGlobalTransformer", - "BLTTransformerLayer", - "BLTForCausalLM", -] \ No newline at end of file diff --git a/src/transformers/models/blt_wip/tokenization_blt.py b/src/transformers/models/blt_wip/tokenization_blt.py index dfbe1602583c..ff4004f6261b 100644 --- a/src/transformers/models/blt_wip/tokenization_blt.py +++ b/src/transformers/models/blt_wip/tokenization_blt.py @@ -14,8 +14,6 @@ # limitations under the License. """Tokenization classes for BLT.""" -import os -import torch from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from ...tokenization_utils import AddedToken, PreTrainedTokenizer @@ -270,17 +268,4 @@ def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" return self.vocab_size_unit_1 + self.offsetting_special_char - def __call__(self, text, **kwargs): - """Override the default __call__ method to properly handle BOS/EOS tokens.""" - # Use our custom encode method to ensure consistent behavior - if isinstance(text, str): - tokens = self.encode(text, add_bos=self.add_bos_token, add_eos=self.add_eos_token) - return {"input_ids": torch.tensor([tokens])} - elif isinstance(text, list): - tokens_list = [self.encode(t, add_bos=self.add_bos_token, add_eos=self.add_eos_token) for t in text] - return {"input_ids": torch.tensor(tokens_list)} - else: - # Fallback to parent implementation - return super().__call__(text, **kwargs) - __all__ = ["BLTTokenizer"] \ No newline at end of file From d9d6d7303e8e4f7faec890ae2a6b27253c063e8d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 16 May 2025 09:17:51 +0000 Subject: [PATCH 056/139] adding files after add-new-model-like --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/blt.md | 102 ++ .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/blt/__init__.py | 0 .../{blt_wip => blt}/configuration_blt.py | 0 .../models/blt/convert_blt_weights_to_hf.py | 397 ++++++++ .../models/{blt_wip => blt}/modeling_blt.py | 0 .../models/{blt_wip => blt}/modular_blt.py | 0 .../{blt_wip => blt}/tokenization_blt.py | 0 tests/models/blt/__init__.py | 0 tests/models/blt/test_modeling_blt.py | 930 ++++++++++++++++++ tests/models/blt/test_tokenization_blt.py | 914 +++++++++++++++++ 13 files changed, 2349 insertions(+) create mode 100644 docs/source/en/model_doc/blt.md create mode 100644 src/transformers/models/blt/__init__.py rename src/transformers/models/{blt_wip => blt}/configuration_blt.py (100%) create mode 100644 src/transformers/models/blt/convert_blt_weights_to_hf.py rename src/transformers/models/{blt_wip => blt}/modeling_blt.py (100%) rename src/transformers/models/{blt_wip => blt}/modular_blt.py (100%) rename src/transformers/models/{blt_wip => blt}/tokenization_blt.py (100%) create mode 100644 tests/models/blt/__init__.py create mode 100644 tests/models/blt/test_modeling_blt.py create mode 100644 tests/models/blt/test_tokenization_blt.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3d1b0b169636..062b0eb50637 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -411,6 +411,8 @@ title: Blenderbot Small - local: model_doc/bloom title: BLOOM + - local: model_doc/blt + title: BLT - local: model_doc/bort title: BORT - local: model_doc/byt5 diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md new file mode 100644 index 000000000000..8ab1fcc5dfdd --- /dev/null +++ b/docs/source/en/model_doc/blt.md @@ -0,0 +1,102 @@ + + +
+
+ PyTorch + Flax + FlashAttention + SDPA +
+
+ +# BLT + +# BLT + +## Overview + +The BLT model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## BLTConfig + +[[autodoc]] BLTConfig + +## BLTTokenizer + +[[autodoc]] BLTTokenizer + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + +## BLTTokenizerFast + +[[autodoc]] BLTTokenizerFast + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - update_post_processor + - save_vocabulary + +## BLTModel + +[[autodoc]] BLTModel + - forward + +## BLTForCausalLM + +[[autodoc]] BLTForCausalLM + - forward + +## BLTForSequenceClassification + +[[autodoc]] BLTForSequenceClassification + - forward + +## BLTForQuestionAnswering + +[[autodoc]] BLTForQuestionAnswering + - forward + +## BLTForTokenClassification + +[[autodoc]] BLTForTokenClassification + - forward + +## FlaxBLTModel + +[[autodoc]] FlaxBLTModel + - __call__ + +## FlaxBLTForCausalLM + +[[autodoc]] FlaxBLTForCausalLM + - __call__ diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 06023f09c9d8..3ddd8f367c84 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -226,6 +226,7 @@ ("lightglue", "LightGlueConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), + ("blt", "BLTConfig"), ("llama4", "Llama4Config"), ("llama4_text", "Llama4TextConfig"), ("llava", "LlavaConfig"), @@ -663,6 +664,7 @@ ("lightglue", "LightGlue"), ("lilt", "LiLT"), ("llama", "LLaMA"), + ("blt", "BLT"), ("llama2", "Llama2"), ("llama3", "Llama3"), ("llama4", "Llama4"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 025a7a1f90a0..abc4557ecfe4 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -226,6 +226,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("lightglue", "LightGlueForKeypointMatching"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), + ("blt", "BLTModel"), ("llama4", "Llama4ForConditionalGeneration"), ("llama4_text", "Llama4TextModel"), ("llava", "LlavaModel"), @@ -688,6 +689,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("jetmoe", "JetMoeForCausalLM"), ("lfm2", "Lfm2ForCausalLM"), ("llama", "LlamaForCausalLM"), + ("blt", "BLTForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), ("longcat_flash", "LongcatFlashForCausalLM"), diff --git a/src/transformers/models/blt/__init__.py b/src/transformers/models/blt/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/transformers/models/blt_wip/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py similarity index 100% rename from src/transformers/models/blt_wip/configuration_blt.py rename to src/transformers/models/blt/configuration_blt.py diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py new file mode 100644 index 000000000000..26c05477a169 --- /dev/null +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -0,0 +1,397 @@ +import argparse +import json +import logging +import os +from typing import Any, Dict, Optional + +import torch +from huggingface_hub import hf_hub_download, upload_folder +from safetensors.torch import load_file, save_file + +from transformers.models.blt_wip.configuration_blt import BLTConfig +from transformers.models.blt_wip.modeling_blt import BLTModel +from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM +from transformers.utils import logging as transformers_logging + + +logger = transformers_logging.get_logger(__name__) +transformers_logging.set_verbosity_info() + + +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: + logger.info("Merging configurations") + + with open(config_path, "r") as f: + main_config = json.load(f) + + with open(entropy_params_path, "r") as f: + entropy_data = json.load(f) + + entropy_model_params = entropy_data.get("entropy_model", {}) + patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) + + unified_config = main_config.copy()["args"] + + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: + if key in unified_config and not isinstance(unified_config[key], int): + unified_config[key] = int(unified_config[key]) + + patch_size = patcher_args.get("patch_size", 8) + if isinstance(patch_size, float): + patch_size = int(patch_size) + + # Create patcher config + patcher_hidden_size = int(entropy_model_params.get("dim", 512)) + patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256)) + patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of) + + patcher_config = { + "vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "hidden_size": patcher_hidden_size, + "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)), + "num_attention_heads": int(entropy_model_params.get("n_heads", 8)), + "num_key_value_heads": int(entropy_model_params.get("n_kv_heads")) + if entropy_model_params.get("n_kv_heads") is not None + else None, + "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)), + "norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "dropout": entropy_model_params.get("dropout", 0.0), + "rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "intermediate_size": patcher_intermediate_size, + } + + # Create encoder config + encoder_hidden_size = unified_config.get("dim_local_encoder", 1024) + encoder_multiple_of = unified_config.get("multiple_of", 256) + encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of) + + encoder_config = { + "vocab_size": unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "pm_size": unified_config.get("pm_size", 0), + "hidden_size": encoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_encoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": encoder_intermediate_size, + } + + # Create decoder config + decoder_hidden_size = unified_config.get("dim_local_decoder", 1024) + decoder_multiple_of = unified_config.get("multiple_of", 256) + decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of) + + decoder_config = { + "vocab_size": unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "hidden_size": decoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_decoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": decoder_intermediate_size, + } + + # Create global transformer config + global_hidden_size = unified_config.get("dim_global", 2048) + global_multiple_of = unified_config.get("multiple_of", 256) + global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of) + + global_config = { + "hidden_size": global_hidden_size, + "num_attention_heads": unified_config.get("n_heads_global", 16), + "num_key_value_heads": unified_config.get("n_kv_heads_global"), + "num_hidden_layers": unified_config.get("n_layers_global", 25), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": global_intermediate_size, + } + + # Create main config with sub-configs + main_config_dict = { + "model_type": "blt", + "vocab_size": unified_config.get("vocab_size", 256), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"), + "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000), + "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3), + "pm_size": unified_config.get("pm_size", 0), + "patcher_config": patcher_config, + "encoder_config": encoder_config, + "decoder_config": decoder_config, + "global_config": global_config, + } + + main_config_dict["tie_word_embeddings"] = False + + logger.info(f"Merged configuration with {len(main_config_dict)} parameters") + return main_config_dict + + +def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + component_mappings = { + ".attention.": ".self_attn.", + ".feed_forward.": ".mlp.", + ".attention_norm.": ".input_layernorm.", + ".ffn_norm.": ".post_attention_layernorm.", + ".tok_embeddings.": ".embed_tokens.", + ".cross_attn_norm_q.": ".q_norm.", + ".cross_attn_norm_kv.": ".k_norm.", + ".w1.": ".gate_proj.", + ".w2.": ".down_proj.", + ".w3.": ".up_proj.", + ".wq.": ".q_proj.", + ".wk.": ".k_proj.", + ".wv.": ".v_proj.", + ".wo.": ".o_proj.", + ".output.": ".lm_head.", + } + + new_state_dict = {} + + for old_key, tensor in state_dict.items(): + new_key = old_key + + for old_pattern, new_pattern in component_mappings.items(): + if old_pattern in new_key: + new_key = new_key.replace(old_pattern, new_pattern) + + new_state_dict[new_key] = tensor + + return new_state_dict + + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: + main_weights = load_file(weights_path) + + entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) + + if "model" in entropy_weights: + entropy_weights = entropy_weights["model"] + elif "state_dict" in entropy_weights: + entropy_weights = entropy_weights["state_dict"] + + unified_weights = main_weights.copy() + + for key, tensor in entropy_weights.items(): + patcher_key = f"patcher.{key}" + unified_weights[patcher_key] = tensor + + unified_weights = apply_weight_mapping(unified_weights) + + decoder_lm_head_key = "local_decoder.lm_head.weight" + top_lm_head_key = "lm_head.weight" + unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key] + del unified_weights[decoder_lm_head_key] + + prefixed_weights = {} + for key, tensor in unified_weights.items(): + if key == top_lm_head_key: + prefixed_weights[key] = tensor + elif not key.startswith("model."): + prefixed_weights[f"model.{key}"] = tensor + else: + prefixed_weights[key] = tensor + + unified_weights = prefixed_weights + + return unified_weights + + +def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): + tokenizer_config = { + "tokenizer_class": "BltTokenizer", + "vocab_size": config.get("vocab_size", 256), + "model_max_length": config.get("max_seqlen", 1024), + "add_bos_token": True, + "add_eos_token": True, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") + with open(tokenizer_path, "w") as f: + json.dump(tokenizer_config, f, indent=2) + + +def push_to_hub( + local_dir: str, + repo_id: str, + commit_message: str = "Upload converted BLT model", + private: bool = False, + token: Optional[str] = None, +) -> None: + try: + upload_folder( + folder_path=local_dir, + repo_id=repo_id, + commit_message=commit_message, + repo_type="model", + token=token, + ) + logger.info(f"Successfully pushed model to {repo_id}") + + except Exception as e: + logger.error(f"Failed to push model to Hub: {e}") + raise + + +def convert_hf_blt_to_unified( + model_id: str, + output_dir: str, + config_name: str = "config.json", + weights_name: str = "model.bin", + cache_dir: Optional[str] = None, + push_to_hub_repo: Optional[str] = None, + hub_private: bool = False, + hub_token: Optional[str] = None, +) -> None: + # Download model files + config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) + weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) + entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) + entropy_weights_path = hf_hub_download( + repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir + ) + + unified_config = merge_configurations(config_path, entropy_params_path) + unified_weights = merge_weights(weights_path, entropy_weights_path) + + os.makedirs(output_dir, exist_ok=True) + + config_path = os.path.join(output_dir, config_name) + with open(config_path, "w") as f: + json.dump(unified_config, f, indent=2) + + if weights_name.endswith(".bin"): + weights_name = weights_name.replace(".bin", ".safetensors") + + weights_path = os.path.join(output_dir, weights_name) + save_file(unified_weights, weights_path) + + create_tokenizer_config(output_dir, unified_config) + + logger.info(f"Conversion completed, model saved to: {output_dir}") + + if push_to_hub_repo: + push_to_hub( + local_dir=output_dir, + repo_id=push_to_hub_repo, + commit_message="Upload BLT model converted", + private=hub_private, + token=hub_token, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert BLT models from HuggingFace Hub format to unified format", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--model_id", + type=str, + default="facebook/blt-1b", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./blt_converted", + ) + parser.add_argument( + "--config_name", + type=str, + default="config.json", + ) + parser.add_argument( + "--weights_name", + type=str, + default="model.bin", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + ) + parser.add_argument( + "--debug", + action="store_true", + default=True, + ) + parser.add_argument( + "--push_to_hub", + type=str, + default=None, + ) + parser.add_argument( + "--hub_private", + action="store_true", + default=False, + ) + parser.add_argument( + "--hub_token", + type=str, + default="hf_token", + ) + + args = parser.parse_args() + + transformers_logging.set_verbosity_debug() + logging.basicConfig(level=logging.DEBUG) + + try: + convert_hf_blt_to_unified( + model_id=args.model_id, + output_dir=args.output_dir, + config_name=args.config_name, + weights_name=args.weights_name, + cache_dir=args.cache_dir, + push_to_hub_repo=args.push_to_hub, + hub_private=args.hub_private, + hub_token=args.hub_token, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/blt_wip/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py similarity index 100% rename from src/transformers/models/blt_wip/modeling_blt.py rename to src/transformers/models/blt/modeling_blt.py diff --git a/src/transformers/models/blt_wip/modular_blt.py b/src/transformers/models/blt/modular_blt.py similarity index 100% rename from src/transformers/models/blt_wip/modular_blt.py rename to src/transformers/models/blt/modular_blt.py diff --git a/src/transformers/models/blt_wip/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py similarity index 100% rename from src/transformers/models/blt_wip/tokenization_blt.py rename to src/transformers/models/blt/tokenization_blt.py diff --git a/tests/models/blt/__init__.py b/tests/models/blt/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py new file mode 100644 index 000000000000..47ca2d000f11 --- /dev/null +++ b/tests/models/blt/test_modeling_blt.py @@ -0,0 +1,930 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch BLT model.""" + +import unittest + +from packaging import version +from parameterized import parameterized + +from transformers import AutoTokenizer, BLTConfig, StaticCache, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.testing_utils import ( + Expectations, + cleanup, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + BLTForCausalLM, + BLTForQuestionAnswering, + BLTForSequenceClassification, + BLTForTokenClassification, + BLTModel, + BLTTokenizer, + ) + from transformers.models.blt.modeling_blt import BLTRotaryEmbedding + + +class BLTModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return BLTConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = BLTModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + BLTModel, + BLTForCausalLM, + BLTForSequenceClassification, + BLTForQuestionAnswering, + BLTForTokenClassification, + ) + if is_torch_available() + else () + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None + + def setUp(self): + self.model_tester = BLTModelTester(self) + self.config_tester = ConfigTester(self, config_class=BLTConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_blt_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = BLTForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_blt_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = BLTForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_blt_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = BLTForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_blt_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = BLTForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = BLTModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = BLTModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn( + 1, dtype=torch.float32, device=torch_device + ) # used exclusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) + + # Sanity check original RoPE + original_rope = BLTRotaryEmbedding(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + + # Sanity check Yarn RoPE scaling + # Scaling should be over the entire input + config.rope_scaling = {"type": "yarn", "factor": scaling_factor} + yarn_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + def test_model_loading_old_rope_configs(self): + def _reinitialize_config(base_config, new_kwargs): + # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation + # steps. + base_config_dict = base_config.to_dict() + new_config = BLTConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) + return new_config + + # from untouched config -> ✅ + base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() + original_model = BLTForCausalLM(base_config).to(torch_device) + original_model(**model_inputs) + + # from a config with the expected rope configuration -> ✅ + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) + original_model = BLTForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC + config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) + original_model = BLTForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) + config = _reinitialize_config( + base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} + ) + self.assertTrue(config.rope_scaling["type"] == "linear") + self.assertTrue(config.rope_scaling["rope_type"] == "linear") + original_model = BLTForCausalLM(config).to(torch_device) + original_model(**model_inputs) + + # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) + original_model = BLTForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("factor field", logs.output[0]) + + # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning + with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: + config = _reinitialize_config( + base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} + ) + original_model = BLTForCausalLM(config).to(torch_device) + original_model(**model_inputs) + self.assertEqual(len(logs.output), 1) + self.assertIn("Unrecognized keys", logs.output[0]) + + # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception + with self.assertRaises(KeyError): + config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" + + +@require_torch_accelerator +class BLTIntegrationTest(unittest.TestCase): + def tearDown(self): + # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves + # some memory allocated in the cache, which means some object is not being released properly. This causes some + # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU. + # Investigate the root cause. + cleanup(torch_device, gc_collect=False) + + @slow + @require_read_token + def test_blt_3_1_hard(self): + """ + An integration test for blt 3.1. It tests against a long output to ensure the subtle numerical differences + from blt 3.1.'s RoPE can be detected + """ + # diff on `EXPECTED_TEXT`: + # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. + EXPECTED_TEXT = ( + "Tell me about the french revolution. The french revolution was a period of radical political and social " + "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " + "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " + "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " + "assembly that had not met since 1614. The Third Estate, which represented the common people, " + "demanded greater representation and eventually broke away to form the National Assembly. This marked " + "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" + ) + + tokenizer = AutoTokenizer.from_pretrained("meta-blt/Meta-BLT-3.1-8B-Instruct") + model = BLTForCausalLM.from_pretrained( + "meta-blt/Meta-BLT-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 + ) + input_text = ["Tell me about the french revolution."] + model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(generated_text, EXPECTED_TEXT) + + @slow + @require_read_token + def test_model_7b_logits_bf16(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = BLTForCausalLM.from_pretrained( + "meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + # Expected mean on dim = -1 + + # fmt: off + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), + ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), + ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), + ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), + ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) + }) + # fmt: on + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + @require_read_token + def test_model_7b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + + model = BLTForCausalLM.from_pretrained( + "meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + + with torch.no_grad(): + out = model(torch.tensor([input_ids]).to(torch_device)) + + # fmt: off + # Expected mean on dim = -1 + expected_means = Expectations( + { + ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), + ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), + }) + + expected_mean = expected_means.get_expectation() + self.assertTrue( + torch.allclose( + expected_mean.to(torch_device), + out.logits.float().mean(-1), + atol=1e-2, + rtol=1e-2 + ) + ) + + # slicing logits[0, 0, 0:15] + expected_slices = Expectations( + { + ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), + ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), + ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) + }) + # fmt: on + + expected_slice = expected_slices.get_expectation() + self.assertTrue( + torch.allclose( + expected_slice.to(torch_device), + out.logits[0, 0, :15].float(), + atol=1e-2, + rtol=1e-2, + ) + ) + + @slow + def test_model_7b_dola_generation(self): + # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + EXPECTED_TEXT_COMPLETION = ( + "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " + "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " + "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " + "understanding of space and time." + ) + prompt = "Simply put, the theory of relativity states that " + tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-chat-hf") + model = BLTForCausalLM.from_pretrained( + "meta-blt/BLT-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 + ) + model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate( + **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" + ) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_torch_accelerator + @require_read_token + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf", pad_token="", padding_side="right") + model = BLTForCausalLM.from_pretrained( + "meta-blt/BLT-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + @slow + @require_read_token + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + blt_models = { + "meta-blt/BLT-3.2-1B": [ + "Simply put, the theory of relativity states that 1) the speed of light is the same for all " + "observers, regardless of their location, and 2) the laws of physics are the same for all observers" + ], + } + + for blt_model_ckp, EXPECTED_TEXT_COMPLETION in blt_models.items(): + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(blt_model_ckp, pad_token="", padding_side="right") + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = BLTForCausalLM.from_pretrained( + blt_model_ckp, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": device, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + + +@slow +@require_torch_accelerator +class Mask4DTestHard(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + model_name = "TinyBLT/TinyBLT-1.1B-Chat-v1.0" + self.model_dtype = torch.float32 + self.tokenizer = BLTTokenizer.from_pretrained(model_name) + self.model = BLTForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) + + def get_test_data(self): + template = "my favorite {}" + items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item + + batch_separate = [template.format(x) for x in items] # 3 separate lines + batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated + + input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) + input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) + + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], + ] + ] + ], + device=torch_device, + ) + + position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) + + # building custom positions ids based on custom mask + position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) + # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) + + # inverting the mask + min_dtype = torch.finfo(self.model_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + def test_stacked_causal_mask(self): + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # 2 forward runs with custom 4D masks + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) + past_key_values_a = outs_1a["past_key_values"] + + # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + outs_1b = self.model.forward( + input_1b, + attention_mask=mask_1b, + position_ids=position_ids_1b, + past_key_values=past_key_values_a, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) + + def test_stacked_causal_mask_static_cache(self): + """same as above but with StaticCache""" + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + padded_attention_mask = torch.nn.functional.pad( + input=mask_shared_prefix, + pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + # single forward run with 4D custom mask + logits_shared_prefix = self.model.forward( + input_ids_shared_prefix, + attention_mask=padded_attention_mask, + position_ids=position_ids_shared_prefix, + cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), + past_key_values=past_key_values, + ).logits + logits_shared_prefix_last = logits_shared_prefix[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : + ] # last three tokens + decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] + + self.assertEqual(decoded, decoded_shared_prefix) + + def test_partial_stacked_causal_mask_static_cache(self): + # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks + # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self.get_test_data() + + # regular batch + logits = self.model.forward(input_ids, position_ids=position_ids).logits + logits_last = logits[:, -1, :] # last tokens in each batch line + decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] + + # upgrade the model with StaticCache + max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] + past_key_values = StaticCache( + config=self.model.config, + max_batch_size=1, + max_cache_len=max_cache_len, + device=torch_device, + dtype=self.model.dtype, + ) + + # forward run for the first part of input + part_a = 3 # split point + + input_1a = input_ids_shared_prefix[:, :part_a] + position_ids_1a = position_ids_shared_prefix[:, :part_a] + mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] + + padded_mask_1a = torch.nn.functional.pad( + input=mask_1a, + pad=(0, max_cache_len - mask_1a.shape[-1]), + mode="constant", + value=torch.finfo(self.model_dtype).min, + ) + + _ = self.model.forward( + input_1a, + attention_mask=padded_mask_1a, + position_ids=position_ids_1a, + cache_position=torch.arange(part_a, device=torch_device), + past_key_values=past_key_values, + ) + + # forward run for the second part of input + input_1b = input_ids_shared_prefix[:, part_a:] + position_ids_1b = position_ids_shared_prefix[:, part_a:] + mask_1b = mask_shared_prefix[:, :, part_a:, :] + + padded_mask_1b = torch.nn.functional.pad( + input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 + ) + + outs_1b = self.model.forward( + input_1b, + attention_mask=padded_mask_1b, + position_ids=position_ids_1b, + cache_position=torch.arange( + part_a, + input_ids_shared_prefix.shape[-1], + device=torch_device, + ), + past_key_values=past_key_values, + ) + decoded_1b = [ + self.tokenizer.decode(t) + for t in outs_1b.logits.argmax(-1)[ + 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a + ] + ] + self.assertEqual(decoded, decoded_1b) diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py new file mode 100644 index 000000000000..62af101b1d83 --- /dev/null +++ b/tests/models/blt/test_tokenization_blt.py @@ -0,0 +1,914 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import shutil +import tempfile +import unittest + +from datasets import load_dataset +from huggingface_hub import hf_hub_download + +from transformers import ( + SPIECE_UNDERLINE, + AddedToken, + AutoTokenizer, + BLTTokenizer, + BLTTokenizerFast, + PreTrainedTokenizerFast, +) +from transformers.convert_slow_tokenizer import convert_slow_tokenizer +from transformers.testing_utils import ( + get_tests_dir, + nested_simplify, + require_jinja, + require_read_token, + require_sentencepiece, + require_tiktoken, + require_tokenizers, + require_torch, + slow, +) + +from ...test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +@require_sentencepiece +@require_tokenizers +class BLTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + from_pretrained_id = ["hf-internal-testing/blt-tokenizer", "meta-blt/BLT-2-7b-hf"] + tokenizer_class = BLTTokenizer + rust_tokenizer_class = BLTTokenizerFast + + test_rust_tokenizer = False + test_sentencepiece = True + from_pretrained_kwargs = {} + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # We have a SentencePiece fixture for testing + tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.save_pretrained(cls.tmpdirname) + + def get_tokenizers(self, **kwargs): + kwargs.update({"pad_token": ""}) + return super().get_tokenizers(**kwargs) + + def test_full_tokenizer(self): + tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [285, 46, 10, 170, 382], + ) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "9", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "é", + ".", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, + [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4], + ) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "", + ".", + ], + ) + + @unittest.skip(reason="Let's wait for the fast tokenizer!") + def test_save_pretrained(self): + self.tokenizers_list += (self.rust_tokenizer_class, "hf-internal-testing/blt-tokenizer", {}) + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs) + tokenizer_p = self.get_tokenizer(pretrained_name, **kwargs) + + tmpdirname2 = tempfile.mkdtemp() + + tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2) + tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) + + # Checks it save with the same files + the tokenizer.json file for the fast one + self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files)) + tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f) + self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files) + + # Checks everything loads correctly in the same way + tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) + tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) + + # Check special tokens are set accordingly on Rust and Python + for key in tokenizer_pp.special_tokens_map: + self.assertTrue(hasattr(tokenizer_rp, key)) + + shutil.rmtree(tmpdirname2) + + # Save tokenizer rust, legacy_format=True + tmpdirname2 = tempfile.mkdtemp() + + tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True) + tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) + + # Checks it save with the same files + self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files) + + # Checks everything loads correctly in the same way + tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) + tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) + + # Check special tokens are set accordingly on Rust and Python + for key in tokenizer_pp.special_tokens_map: + self.assertTrue(hasattr(tokenizer_rp, key)) + + shutil.rmtree(tmpdirname2) + + # Save tokenizer rust, legacy_format=False + tmpdirname2 = tempfile.mkdtemp() + + tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False) + tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) + + # Checks it saved the tokenizer.json file + self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files)) + + # Checks everything loads correctly in the same way + tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) + tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) + + # Check special tokens are set accordingly on Rust and Python + for key in tokenizer_pp.special_tokens_map: + self.assertTrue(hasattr(tokenizer_rp, key)) + + shutil.rmtree(tmpdirname2) + + @require_torch + def test_batch_tokenization(self): + if not self.test_seq2seq: + self.skipTest(reason="test_seq2seq is set to False") + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Longer text that will definitely require truncation. + text = [ + " UN Chief Says There Is No Military Solution in Syria", + " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for" + " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons" + " will only worsen the violence and misery for millions of people.", + ] + try: + batch = tokenizer( + text=text, + max_length=3, + max_target_length=10, + return_tensors="pt", + ) + except NotImplementedError: + self.skipTest(reason="Encountered NotImplementedError when calling tokenizer") + self.assertEqual(batch.input_ids.shape[1], 3) + # max_target_length will default to max_length if not specified + batch = tokenizer(text, max_length=3, return_tensors="pt") + self.assertEqual(batch.input_ids.shape[1], 3) + + batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt") + self.assertEqual(batch_encoder_only.input_ids.shape[1], 3) + self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3) + self.assertNotIn("decoder_input_ids", batch_encoder_only) + + @unittest.skip(reason="Unfortunately way too slow to build a BPE with SentencePiece.") + def test_save_slow_from_fast_and_reload_fast(self): + pass + + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + added_tokens = [AddedToken("", lstrip=True)] + + tokenizer_r = self.get_rust_tokenizer( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + r_output = tokenizer_r.encode("Hey this is a token") + + special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] + + self.assertTrue(special_token_id in r_output) + + if self.test_slow_tokenizer: + tokenizer_cr = self.get_rust_tokenizer( + pretrained_name, + additional_special_tokens=added_tokens, + **kwargs, # , from_slow=True <- unfortunately too slow to convert + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + p_output = tokenizer_p.encode("Hey this is a token") + + cr_output = tokenizer_cr.encode("Hey this is a token") + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in cr_output) + + @slow + def test_tokenizer_integration(self): + expected_encoding = {'input_ids': [[1, 4103, 689, 414, 313, 24784, 368, 2998, 408, 282, 3637, 25350, 29899, 9067, 414, 322, 282, 3637, 25350, 29899, 1457, 3018, 1312, 29899, 2151, 29897, 8128, 2498, 29899, 15503, 4220, 6956, 1973, 313, 13635, 29911, 29892, 402, 7982, 29899, 29906, 29892, 1528, 13635, 29911, 29874, 29892, 1060, 26369, 29892, 6652, 309, 29933, 814, 29892, 1060, 29931, 6779, 11410, 363, 18385, 17088, 7634, 11235, 313, 25103, 29965, 29897, 322, 18385, 17088, 28203, 313, 25103, 29954, 29897, 411, 975, 29871, 29941, 29906, 29974, 758, 3018, 1312, 4733, 297, 29871, 29896, 29900, 29900, 29974, 10276, 322, 6483, 1006, 3372, 3097, 1546, 435, 1165, 29892, 10772, 29911, 25350, 322, 323, 6073, 17907, 29889], [1, 350, 20161, 338, 8688, 304, 758, 29899, 14968, 6483, 21000, 8684, 284, 22540, 515, 443, 29880, 24025, 1426, 491, 14002, 368, 4195, 292, 373, 1716, 2175, 322, 1492, 3030, 297, 599, 15359, 29889], [1, 450, 4996, 17354, 1701, 29916, 432, 17204, 975, 278, 17366, 11203, 29889]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip + + self.tokenizer_integration_test_util( + expected_encoding=expected_encoding, + model_name="hf-internal-testing/blt-tokenizer", + revision="0984d03108b1a041ed679bd253b6519b7e1a4778", + padding=False, + ) + + def test_picklable(self): + with tempfile.NamedTemporaryFile() as f: + shutil.copyfile(SAMPLE_VOCAB, f.name) + tokenizer = BLTTokenizer(f.name, keep_accents=True) + pickled_tokenizer = pickle.dumps(tokenizer) + pickle.loads(pickled_tokenizer) + + @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.") + def test_pickle_subword_regularization_tokenizer(self): + pass + + @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.") + def test_subword_regularization_tokenizer(self): + pass + + def test_add_prefix_space(self): + pretrained_name = "hf-internal-testing/blt-tokenizer-non-normalized" + inputs = "Hey how are you doing" + EXPECTED_WITH_SPACE = [1, 18637, 920, 526, 366, 2599] + EXPECTED_WO_SPACE = [1, 29950, 1032, 920, 526, 366, 2599] + + slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=False, legacy=False) + fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=False, legacy=False) + self.assertEqual(slow_.encode(inputs), EXPECTED_WO_SPACE) + self.assertEqual(slow_.encode(inputs), fast_.encode(inputs)) + self.assertEqual(slow_.tokenize(inputs), ["H", "ey", "▁how", "▁are", "▁you", "▁doing"]) + self.assertEqual(slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), inputs) + self.assertEqual( + slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), + fast_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), + ) + + slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=True, legacy=False) + fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=True, legacy=False) + self.assertEqual(slow_.encode(inputs), EXPECTED_WITH_SPACE) + self.assertEqual(slow_.encode(inputs), fast_.encode(inputs)) + self.assertEqual(slow_.tokenize(inputs), ["▁Hey", "▁how", "▁are", "▁you", "▁doing"]) + self.assertEqual(slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), inputs) + self.assertEqual( + slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), + fast_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), + ) + + def test_load_tokenizer_with_model_file_only(self): + with tempfile.TemporaryDirectory() as tmp_dir: + hf_hub_download(repo_id="huggyblt/blt-7b", filename="tokenizer.model", local_dir=tmp_dir) + tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir) + self.assertEqual(tokenizer_fast.encode("This is a test"), [1, 910, 338, 263, 1243]) + + tokenizer_slow = self.tokenizer_class.from_pretrained(tmp_dir) + self.assertEqual(tokenizer_slow.encode("This is a test"), [1, 910, 338, 263, 1243]) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class BLTIntegrationTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + checkpoint_name = "hf-internal-testing/blt-tokenizer-non-normalized" + cls.tokenizer: BLTTokenizer = BLTTokenizer.from_pretrained(checkpoint_name) + cls.rust_tokenizer = BLTTokenizerFast.from_pretrained(checkpoint_name) + return cls + + @require_torch + def integration_tests(self): + inputs = self.tokenizer( + ["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"], + return_tensors="pt", + ) + + self.assertEqual( + nested_simplify(inputs), + { + "input_ids": [ + [1, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889], + [1, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718], + ], + "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], + }, + ) + + def test_fast_special_tokens(self): + slow_tokenizer = self.tokenizer + fast_tokenizer = self.rust_tokenizer + slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) + assert slow == [1, 319, 4559, 1243] + + fast_tokenizer.add_eos_token = False + fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) + assert fast == [1, 319, 4559, 1243] + + fast_tokenizer.add_eos_token = True + print(fast_tokenizer.add_eos_token) + fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) + assert fast == [1, 319, 4559, 1243, 2] + + slow_tokenizer.add_eos_token = True + slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) + assert slow == [1, 319, 4559, 1243, 2] + + fast_tokenizer = BLTTokenizerFast.from_pretrained( + "hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False + ) + fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) + assert fast == [319, 4559, 1243, 2] + + slow_tokenizer = BLTTokenizer.from_pretrained( + "hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False + ) + slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) + assert slow == [319, 4559, 1243, 2] + + self.tokenizer.add_eos_token = False + self.rust_tokenizer.add_eos_token = False + + @slow + def test_conversion(self): + # This is excruciatingly slow since it has to recreate the entire merge + # list from the original vocabulary in spm + self.rust_tokenizer.save_pretrained("./out") + with tempfile.TemporaryDirectory() as dirname: + self.rust_tokenizer.save_pretrained(dirname) + + with open(os.path.join(dirname, "tokenizer.json")) as f: + old_serialized = f.read() + + new_tokenizer = convert_slow_tokenizer(self.tokenizer) + with tempfile.NamedTemporaryFile() as f: + new_tokenizer.save(f.name) + # Re-opening since `f` is in bytes. + new_serialized = open(f.name).read() + with open("out_tokenizer.json", "w") as g: + g.write(new_serialized) + + self.assertEqual(old_serialized, new_serialized) + + def test_simple_encode_decode(self): + pyth_tokenizer = self.tokenizer + rust_tokenizer = self.rust_tokenizer + + self.assertEqual(pyth_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243]) + self.assertEqual(rust_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243]) + self.assertEqual(pyth_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test") + self.assertEqual(rust_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test") + + # bytefallback showcase + self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip + self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip + self.assertEqual( + pyth_tokenizer.decode( + [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True + ), + "生活的真谛是", + ) + self.assertEqual( + rust_tokenizer.decode( + [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True + ), + "生活的真谛是", + ) + + # Inner spaces showcase + self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043]) + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043]) + self.assertEqual(pyth_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello") + self.assertEqual(rust_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043]) + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043]) + self.assertEqual(pyth_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello") + self.assertEqual(rust_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(pyth_tokenizer.encode(""), [1]) + self.assertEqual(rust_tokenizer.encode(""), [1]) + + self.assertEqual(pyth_tokenizer.encode(" "), [1, 259]) + self.assertEqual(rust_tokenizer.encode(" "), [1, 259]) + + self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678]) + self.assertEqual(rust_tokenizer.encode(" "), [1, 1678]) + + self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043]) + self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043]) + + def test_no_differences_showcase(self): + pyth_tokenizer = self.tokenizer + rust_tokenizer = self.rust_tokenizer + self.assertEqual(pyth_tokenizer.encode(""), [1]) + self.assertEqual(rust_tokenizer.encode(""), [1]) + + self.assertEqual(pyth_tokenizer.encode(" "), [1, 259]) + self.assertEqual(rust_tokenizer.encode(" "), [1, 259]) + + self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678]) + self.assertEqual(rust_tokenizer.encode(" "), [1, 1678]) + + self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043]) + self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043]) + + self.assertEqual(pyth_tokenizer.encode(""), [1, 1]) + self.assertEqual(rust_tokenizer.encode(""), [1, 1]) + + def test_no_differences_decode(self): + pyth_tokenizer = self.tokenizer + rust_tokenizer = self.rust_tokenizer + + self.assertEqual(pyth_tokenizer.decode([869]), ".") + self.assertEqual(rust_tokenizer.decode([869]), ".") + + self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .") + self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .") + + def test_no_differences_special_tokens(self): + pyth_tokenizer = self.tokenizer + rust_tokenizer = self.rust_tokenizer + self.assertEqual(pyth_tokenizer.encode(""), [1]) + self.assertEqual(rust_tokenizer.encode(""), [1]) + + self.assertEqual(pyth_tokenizer.encode(""), [1, 1]) + self.assertEqual(rust_tokenizer.encode(""), [1, 1]) + + @unittest.skipIf( + os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0", + "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests", + ) + def test_integration_test_xnli(self): + import tqdm + + pyth_tokenizer = self.tokenizer + rust_tokenizer = self.rust_tokenizer + + dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go") + for item in tqdm.tqdm(dataset["validation"]): + string = item["code"] + encoded1 = pyth_tokenizer.encode(string) + encoded2 = rust_tokenizer.encode(string) + + self.assertEqual(encoded1, encoded2) + + decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True) + decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True) + + self.assertEqual(decoded1, decoded2) + + dataset = load_dataset("facebook/xnli", "all_languages") + + for item in tqdm.tqdm(dataset["train"]): + for string in item["premise"].values(): + encoded1 = pyth_tokenizer.encode(string) + encoded2 = rust_tokenizer.encode(string) + + self.assertEqual(encoded1, encoded2) + + decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True) + decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True) + + self.assertEqual(decoded1, decoded2) + + def test_special_token_special_word(self): + # the word inform should be split as ['in', 'form'] + tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True) + tokenizer.add_tokens([AddedToken("", rstrip=True, lstrip=True)], special_tokens=False) + + example_inputs = tokenizer.tokenize("inform. Hey. .") + self.assertEqual(example_inputs, ["", "in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) + + # Make sure dummy space is added if it is indeed the first word + example_inputs = tokenizer.tokenize("inform. Hey. .") + self.assertEqual(example_inputs, ["▁inform", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) + out1 = tokenizer.decode( + tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=False + ) + self.assertEqual(out1, "inform") + out2 = tokenizer.decode( + tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=True + ) + # decoding strips the added prefix space. + self.assertEqual(out2, "inform") + input_ids = tokenizer.encode("inform", add_special_tokens=False) + self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should + + out2 = tokenizer.decode( + tokenizer.encode(" inform", add_special_tokens=False), spaces_between_special_tokens=False + ) + # TODO @ArthurZ currently we strip left and right, so this will not keep the spaces + self.assertEqual(out2, "inform") + + ### Let's make sure decoding does not add extra spaces here and there + # TODO @ArthurZ this should be affected by the lstrip/rstrip/single word /normalize refactoring + # Since currently we always strip left and right of the token, results are as such + input_ids = tokenizer.encode(" Hellohow", add_special_tokens=False) + self.assertEqual(input_ids, [1, 15043, 1, 3525]) + tokens = tokenizer.tokenize(" Hellohow", add_special_tokens=False) + self.assertEqual(tokens, ["", "▁Hello", "", "how"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, " Hellohow") + + # Let's make sure that if there are any spaces, we don't remove them! + input_ids = tokenizer.encode(" Hello how", add_special_tokens=False) + self.assertEqual(input_ids, [29871, 1, 15043, 1, 920]) + tokens = tokenizer.tokenize(" Hello how", add_special_tokens=False) + self.assertEqual(tokens, ["▁", "", "▁Hello", "", "▁how"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, " Hello how") + + # Let's make sure the space is preserved + input_ids = tokenizer.encode("hello", add_special_tokens=True) + self.assertEqual(input_ids, [1, 22172]) + tokens = tokenizer.tokenize("hello") + self.assertEqual(tokens, ["▁hello"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, " hello") + + input_ids = tokenizer.encode("hello", add_special_tokens=False) + self.assertEqual(input_ids, [22172]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, "hello") + + def test_no_prefix_space(self): + tokenizer_no_prefix_space = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", add_prefix_space=False) + no_prefix_space_tokens = tokenizer_no_prefix_space.tokenize("Hey") + self.assertEqual(no_prefix_space_tokens, ["H", "ey"]) + + tokenizer = BLTTokenizerFast.from_pretrained( + "huggyblt/blt-7b", legacy=False, from_slow=True, add_prefix_space=False + ) + tokenizer.add_tokens([AddedToken("", rstrip=True, lstrip=True)], special_tokens=False) + + example_inputs = tokenizer.tokenize("inform. Hey. .") + self.assertEqual(example_inputs, ["", "in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) + + # Make sure dummy space is added if it is indeed the first word + example_inputs = tokenizer.tokenize("inform. Hey. .") + self.assertEqual(example_inputs, ["in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) + out1 = tokenizer.decode( + tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=False + ) + self.assertEqual(out1, "inform") + out2 = tokenizer.decode( + tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=True + ) + # decoding strips the added prefix space. + self.assertEqual(out2, "inform") + input_ids = tokenizer.encode("inform", add_special_tokens=False) + self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should + + out2 = tokenizer.decode( + tokenizer.encode(" inform", add_special_tokens=False), spaces_between_special_tokens=False + ) + self.assertEqual(out2, "inform") + + input_ids = tokenizer.encode(" Hellohow", add_special_tokens=False) + self.assertEqual(input_ids, [1, 15043, 1, 3525]) + tokens = tokenizer.tokenize(" Hellohow", add_special_tokens=False) + self.assertEqual(tokens, ["", "▁Hello", "", "how"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, " Hellohow") + + # Let's make sure that if there are any spaces, we don't remove them! + input_ids = tokenizer.encode(" Hello how", add_special_tokens=False) + self.assertEqual(input_ids, [29871, 1, 15043, 1, 920]) + tokens = tokenizer.tokenize(" Hello how", add_special_tokens=False) + self.assertEqual(tokens, ["▁", "", "▁Hello", "", "▁how"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, " Hello how") + + # Let's make sure the space is preserved + input_ids = tokenizer.encode("hello", add_special_tokens=True) + self.assertEqual(input_ids, [1, 12199]) + tokens = tokenizer.tokenize("hello") + self.assertEqual(tokens, ["hello"]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, "hello") + + input_ids = tokenizer.encode("hello", add_special_tokens=False) + self.assertEqual(input_ids, [12199]) + decoded_tokens = tokenizer.decode(input_ids) + self.assertEqual(decoded_tokens, "hello") + + def test_some_edge_cases(self): + tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False) + + sp_tokens = tokenizer.sp_model.encode(">", out_type=str) + self.assertEqual(sp_tokens, ["<", "s", ">>"]) + tokens = tokenizer.tokenize(">") + self.assertNotEqual(sp_tokens, tokens) + self.assertEqual(tokens, ["", ">"]) + + tokens = tokenizer.tokenize("") + self.assertEqual(tokens, []) + self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str)) + + tokens = tokenizer.tokenize(" ") + self.assertEqual(tokens, ["▁▁"]) + # a dummy prefix space is not added by the sp_model as it was de-activated + self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str)) + + tokens = tokenizer.tokenize("▁") + self.assertEqual(tokens, ["▁▁"]) + # a dummy prefix space is not added by the sp_model as it was de-activated + self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁", out_type=str)) + + tokens = tokenizer.tokenize(" ▁") + self.assertEqual(tokens, ["▁▁▁"]) + # a dummy prefix space is not added by the sp_model as it was de-activated + self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str)) + + def test_fast_post_processor(self): + tokenizer = BLTTokenizerFast( + SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False + ) + tokenizer.encode(" Hey ") + + with self.assertRaises(ValueError): + tokenizer = BLTTokenizerFast( + SAMPLE_VOCAB, bos_token=None, eos_token="", add_bos_token=True, add_eos_token=False + ) + with self.assertRaises(ValueError): + tokenizer = BLTTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True) + + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False) + + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "user", "content": "Hello!"}], + ] + # Matt: The third test case tests the default system message, but if this is ever changed in the + # class/repo code then that test will fail, and the case will need to be updated. + tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + # fmt: off + expected_tokens = [ + [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962], + [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2], + [1, 29961, 25580, 29962, 15043, 29991, 518, 29914, 25580, 29962] + ] + # fmt: on + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) + + +@require_sentencepiece +@require_tokenizers +class CommonSpmIntegrationTests(unittest.TestCase): + """ + A class that regroups important test to make sure that we properly handle the special tokens. + """ + + @classmethod + def setUpClass(cls): + tokenizer = BLTTokenizer(SAMPLE_VOCAB, extra_ids=0, add_bos_token=False, legacy=False) + tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]}) + cls.tokenizer = tokenizer + return cls + + def test_add_dummy_prefix(self): + # make sure `'▁'` is prepended, and outputs match sp_model's + # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute + input_ids = self.tokenizer.encode(". Hello") + self.assertEqual(input_ids, [7, 4, 156, 86, 20]) + sp_encode = self.tokenizer.sp_model.encode(". Hello") + self.assertEqual(input_ids, [7] + sp_encode) + tokens = self.tokenizer.tokenize(". Hello") + self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) + + tokens = self.tokenizer.tokenize("") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str)) + + tokens = self.tokenizer.tokenize(" ") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str)) + + tokens = self.tokenizer.tokenize("▁") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str)) + + def test_remove_extra_whitespaces(self): + # make sure the extra spaces are eaten. Since the sample vocab does not have + # `______`. sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute is set to False + + input_ids = self.tokenizer.encode(" . Hello") + self.assertEqual(input_ids, [7, 4, 156, 86, 20]) + sp_encode = self.tokenizer.sp_model.encode(" . Hello") + self.assertEqual(input_ids, [7] + sp_encode) + tokens = self.tokenizer.tokenize(" . Hello") + self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) + + # `'▁'` is also a whitespace + input_ids = self.tokenizer.encode("▁He is not") + self.assertEqual(input_ids, [156, 46, 44]) + tokens = self.tokenizer.tokenize("▁He is not") + sp_encode = [ + self.tokenizer.sp_model.piece_to_id("▁He"), + self.tokenizer.sp_model.piece_to_id("▁is"), + self.tokenizer.sp_model.piece_to_id("▁not"), + ] + self.assertEqual(input_ids, sp_encode) + self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added + + input_ids = self.tokenizer.encode("▁He is not ▁He") + self.assertEqual(input_ids, [156, 46, 44, 1, 156]) + tokens = self.tokenizer.tokenize("▁He is not ▁He") + self.assertEqual(tokens, ["▁He", "▁is", "▁not", "", "▁He"]) # spaces are eaten by spm + our strip + # make sure that the output after the extra id is the same as if + # extra_id was not there + input_ids = self.tokenizer.encode("▁He is not ▁He") + self.assertEqual(input_ids, [156, 46, 44, 156]) + tokens = self.tokenizer.tokenize("▁He is not ▁He") + self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start + + def test_character_after_special_token(self): + # Make sure that `tokenizer.tokenize` is similar to + # adding the equivalent special token to the vocab + input_ids = self.tokenizer.encode("Hey I") + self.assertEqual(input_ids, [156, 30, 1, 100]) + sp_encode = self.tokenizer.sp_model.encode("Hey .I") + # the last token should be 100 + self.assertEqual(input_ids[-1], sp_encode[-1]) + tokens = self.tokenizer.tokenize("I") + self.assertEqual(tokens, ["", "I"]) + + input_ids = self.tokenizer.encode("Hello, ,") + self.assertEqual(input_ids, [156, 86, 20, 3, 1, 3]) + tokens = self.tokenizer.tokenize("Hello, ,") + self.assertEqual(tokens, ["▁He", "ll", "o", ",", "", ","]) + + def test_special_tokens_strip(self): + input_ids = self.tokenizer.encode(" ,") + self.assertEqual(input_ids, [1, 7, 3]) + tokens = self.tokenizer.tokenize(" ,") + # spaces are eaten by rstrip / lstrip + spm sp_model.encode(" ") = [] + self.assertEqual(tokens, ["", "▁", ","]) + + input_ids = self.tokenizer.encode("No ▁He") + self.assertEqual(input_ids, [284, 1, 156]) + tokens = self.tokenizer.tokenize("No ▁He") + self.assertEqual(tokens, ["▁No", "", "▁He"]) # spaces are eaten by rstrip / lstrip + + +@require_tiktoken +@require_read_token +class TikTokenIntegrationTests(unittest.TestCase): + """ + A class that regroups important test to make sure that we properly handle the special tokens. + """ + + def test_tiktoken_blt(self): + model_path = "hf-internal-testing/blt-3-8b-internal" + subfolder = "original" + test_text = "This is a test sentence." + test_tokens = [128000, 2028, 374, 264, 1296, 11914, 13, 128001] + num_reserved_special_tokens = 256 + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", + "<|python_tag|>", # end of turn + ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)] + + tiktoken_tokenizer = PreTrainedTokenizerFast.from_pretrained( + model_path, + subfolder=subfolder, + additional_special_tokens=special_tokens, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + ) + tokens = tiktoken_tokenizer.tokenize("<|begin_of_text|> " + test_text) + self.assertEqual(tokens[0], "<|begin_of_text|>") + + tiktoken_tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder=subfolder, + legacy=False, + additional_special_tokens=special_tokens, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + add_bos_token=True, + add_eos_token=True, + ) + self.assertTrue(isinstance(tiktoken_tokenizer, PreTrainedTokenizerFast)) + + tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True) + self.assertEqual(tokens, test_tokens) + + tmpdirname = tempfile.mkdtemp() + tiktoken_tokenizer.save_pretrained(tmpdirname) + tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname) + + self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast)) + tokens = tokenizer_reload.encode(test_text, add_special_tokens=True) + self.assertEqual(tokens, test_tokens) + shutil.rmtree(tmpdirname) + + tiktoken_tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder=subfolder, + additional_special_tokens=special_tokens, + bos_token="<|begin_of_text|>", + eos_token="<|end_of_text|>", + from_slow=True, + add_bos_token=True, + add_eos_token=True, + ) + tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True) + self.assertEqual(tokens, test_tokens) From 3e8dc1e52519a992d9898bf8361ae420d687f1d6 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 3 Jul 2025 14:05:45 +0000 Subject: [PATCH 057/139] update demo --- src/demo_hf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 7c64e47b6723..1cf587c427e5 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -5,10 +5,9 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM -from transformers.models.blt_wip.modeling_blt import BLTModel +from transformers.models.blt.modeling_blt import BLTForCausalLM -from transformers.models.blt_wip.tokenization_blt import BLTTokenizer +from transformers.models.blt.tokenization_blt import BLTTokenizer logger = logging.getLogger() From e6bc6398336df2e9b463798e1b3283fd6f8d8b8c Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 3 Jul 2025 16:01:50 +0000 Subject: [PATCH 058/139] working on tests --- src/transformers/models/blt/__init__.py | 28 +++++++++ tests/models/blt/test_modeling_blt.py | 82 +------------------------ 2 files changed, 30 insertions(+), 80 deletions(-) diff --git a/src/transformers/models/blt/__init__.py b/src/transformers/models/blt/__init__.py index e69de29bb2d1..703b81ecdd09 100644 --- a/src/transformers/models/blt/__init__.py +++ b/src/transformers/models/blt/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_blt import * + from .modeling_blt import * + from .tokenization_blt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 47ca2d000f11..b91ccf8e489a 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -41,9 +41,6 @@ from transformers import ( BLTForCausalLM, - BLTForQuestionAnswering, - BLTForSequenceClassification, - BLTForTokenClassification, BLTModel, BLTTokenizer, ) @@ -172,9 +169,6 @@ class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ( BLTModel, BLTForCausalLM, - BLTForSequenceClassification, - BLTForQuestionAnswering, - BLTForTokenClassification, ) if is_torch_available() else () @@ -443,9 +437,9 @@ def test_blt_3_1_hard(self): "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" ) - tokenizer = AutoTokenizer.from_pretrained("meta-blt/Meta-BLT-3.1-8B-Instruct") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") model = BLTForCausalLM.from_pretrained( - "meta-blt/Meta-BLT-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 + "itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16 ) input_text = ["Tell me about the french revolution."] model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) @@ -503,78 +497,6 @@ def test_model_7b_logits_bf16(self): ) ) - @slow - @require_read_token - def test_model_7b_logits(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - - model = BLTForCausalLM.from_pretrained( - "meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.float16 - ) - - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - - # fmt: off - # Expected mean on dim = -1 - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - ("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]), - ("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]), - }) - - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) - ) - - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]), - ("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]), - ("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]) - }) - # fmt: on - - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) - - @slow - def test_model_7b_dola_generation(self): - # ground truth text generated with dola_layers="low", repetition_penalty=1.2 - EXPECTED_TEXT_COMPLETION = ( - "Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of " - "physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of " - "relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our " - "understanding of space and time." - ) - prompt = "Simply put, the theory of relativity states that " - tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-chat-hf") - model = BLTForCausalLM.from_pretrained( - "meta-blt/BLT-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16 - ) - model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - - # greedy generation outputs - generated_ids = model.generate( - **model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low" - ) - text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow @require_torch_accelerator From 7c352ae79e5725bde7a92ff57bdb221f566eee0c Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 4 Jul 2025 17:14:22 +0000 Subject: [PATCH 059/139] first running integration tests --- src/demo_hf.py | 13 +- src/transformers/models/__init__.py | 12 +- .../models/auto/tokenization_auto.py | 41 +- src/transformers/models/blt/__init__.py | 1 - .../models/blt/tokenization_blt.py | 12 +- tests/models/blt/test_modeling_blt.py | 820 ++---------------- 6 files changed, 77 insertions(+), 822 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 1cf587c427e5..97ad40a296bf 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -2,14 +2,11 @@ import os import torch -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file from transformers.models.blt.modeling_blt import BLTForCausalLM from transformers.models.blt.tokenization_blt import BLTTokenizer - logger = logging.getLogger() os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" @@ -22,8 +19,10 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): device = "cuda" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device) - + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b" + ).to(device) + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device) @@ -33,9 +32,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): input_ids, max_new_tokens=200, do_sample=False, - temperature=1.0, - pad_token_id=tokenizer.pad_id, - eos_token_id=tokenizer.eos_id, + temperature=1.0 ) generated_ids = output_ids[0][len(input_ids[0]):] diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index c32c8a795488..5fbeacd21d78 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -22,7 +22,6 @@ from .albert import * from .align import * from .altclip import * - from .arcee import * from .aria import * from .audio_spectrogram_transformer import * from .auto import * @@ -47,6 +46,7 @@ from .blenderbot_small import * from .blip import * from .blip_2 import * + from .blt import * from .bloom import * from .bridgetower import * from .bros import * @@ -65,7 +65,6 @@ from .cohere2 import * from .cohere2_vision import * from .colpali import * - from .colqwen2 import * from .conditional_detr import * from .convbert import * from .convnext import * @@ -93,7 +92,6 @@ from .depth_anything import * from .depth_pro import * from .detr import * - from .dia import * from .dialogpt import * from .diffllama import * from .dinat import * @@ -104,7 +102,6 @@ from .distilbert import * from .dit import * from .donut import * - from .dots1 import * from .dpr import * from .dpt import * from .efficientloftr import * @@ -175,7 +172,6 @@ from .janus import * from .jetmoe import * from .kosmos2 import * - from .kyutai_speech_to_text import * from .layoutlm import * from .layoutlmv2 import * from .layoutlmv3 import * @@ -210,7 +206,6 @@ from .megatron_gpt2 import * from .mgp_str import * from .mimi import * - from .minimax import * from .ministral import * from .mistral import * from .mistral3 import * @@ -269,7 +264,6 @@ from .plbart import * from .poolformer import * from .pop2piano import * - from .prompt_depth_anything import * from .prophetnet import * from .pvt import * from .pvt_v2 import * @@ -279,8 +273,6 @@ from .qwen2_audio import * from .qwen2_moe import * from .qwen2_vl import * - from .qwen3 import * - from .qwen3_moe import * from .qwen3_next import * from .qwen3_vl import * from .qwen3_vl_moe import * @@ -308,7 +300,6 @@ from .seggpt import * from .sew import * from .sew_d import * - from .shieldgemma2 import * from .siglip import * from .siglip2 import * from .smolvlm import * @@ -327,7 +318,6 @@ from .swinv2 import * from .switch_transformers import * from .t5 import * - from .t5gemma import * from .table_transformer import * from .tapas import * from .textnet import * diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7858ae587946..52e32ddd6b77 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -73,7 +73,6 @@ ), ), ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), - ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), @@ -105,6 +104,7 @@ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)), + ("blt", ("BLTTokenizer", None)), ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("byt5", ("ByT5Tokenizer", None)), @@ -157,7 +157,6 @@ ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), - ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), ( "cpm", @@ -208,7 +207,6 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), - ("dia", ("DiaTokenizer", None)), ( "diffllama", ( @@ -277,26 +275,11 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "gemma3n", - ( - "GemmaTokenizer" if is_sentencepiece_available() else None, - "GemmaTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "gemma3n_text", - ( - "GemmaTokenizer" if is_sentencepiece_available() else None, - "GemmaTokenizerFast" if is_tokenizers_available() else None, - ), - ), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), - ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), - ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), @@ -306,10 +289,6 @@ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)), - ("granite", ("GPT2Tokenizer", None)), - ("granitemoe", ("GPT2Tokenizer", None)), - ("granitemoehybrid", ("GPT2Tokenizer", None)), - ("granitemoeshared", ("GPT2Tokenizer", None)), ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), @@ -415,13 +394,6 @@ ), ), ("mgp-str", ("MgpstrTokenizer", None)), - ( - "minimax", - ( - "GPT2Tokenizer" if is_sentencepiece_available() else None, - "GPT2TokenizerFast" if is_tokenizers_available() else None, - ), - ), ( "mistral", ( @@ -673,13 +645,6 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "t5gemma", - ( - "GemmaTokenizer" if is_sentencepiece_available() else None, - "GemmaTokenizerFast" if is_tokenizers_available() else None, - ), - ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), @@ -935,7 +900,7 @@ class AutoTokenizer: """ def __init__(self): - raise OSError( + raise EnvironmentError( "AutoTokenizer is designed to be instantiated " "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." ) diff --git a/src/transformers/models/blt/__init__.py b/src/transformers/models/blt/__init__.py index 703b81ecdd09..c29d2aa5a8f0 100644 --- a/src/transformers/models/blt/__init__.py +++ b/src/transformers/models/blt/__init__.py @@ -16,7 +16,6 @@ from ...utils import _LazyModule from ...utils.import_utils import define_import_structure - if TYPE_CHECKING: from .configuration_blt import * from .modeling_blt import * diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index ff4004f6261b..6973a4f52631 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -14,6 +14,7 @@ # limitations under the License. """Tokenization classes for BLT.""" +import os from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from ...tokenization_utils import AddedToken, PreTrainedTokenizer @@ -229,14 +230,13 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: except (UnicodeDecodeError, ValueError): return "" - def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): + def encode(self, text: str, add_special_tokens: bool = True, **kwargs): """ Encode text exactly like the original BLT tokenizer. """ - if add_bos is None: - add_bos = self.add_bos_token - if add_eos is None: - add_eos = self.add_eos_token + # Handle the standard PreTrainedTokenizer interface + add_bos = kwargs.get('add_bos', self.add_bos_token if add_special_tokens else False) + add_eos = kwargs.get('add_eos', self.add_eos_token if add_special_tokens else False) # Since bpe_delim=False, we use the simple byte encoding tokens = bytes(text, encoding="utf-8", errors="ignore") @@ -268,4 +268,4 @@ def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" return self.vocab_size_unit_1 + self.offsetting_special_char -__all__ = ["BLTTokenizer"] \ No newline at end of file +__all__ = ["BLTTokenizer"] \ No newline at end of file diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index b91ccf8e489a..deb14b569322 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -16,9 +16,8 @@ import unittest from packaging import version -from parameterized import parameterized -from transformers import AutoTokenizer, BLTConfig, StaticCache, is_torch_available, set_seed +from transformers import AutoTokenizer, StaticCache, is_torch_available from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( Expectations, @@ -30,141 +29,39 @@ torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): import torch from transformers import ( + BLTConfig, BLTForCausalLM, BLTModel, - BLTTokenizer, + BLTTokenizer ) from transformers.models.blt.modeling_blt import BLTRotaryEmbedding +# import os +# import gc +# gc.collect() +# torch.cuda.empty_cache() -class BLTModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) +# os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" +# os.environ["TORCH_USE_CUDA_DSA"] = "1" - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return BLTConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = BLTModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class BLTModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = BLTConfig + base_model_class = BLTModel + causal_lm_class = BLTForCausalLM @require_torch -class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class BLTModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( BLTModel, @@ -173,9 +70,19 @@ class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, if is_torch_available() else () ) + pipeline_model_mapping = ( + { + "feature-extraction": BLTModel, + "text-generation": BLTForCausalLM, + } + if is_torch_available() + else {} + ) test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + fx_compatible = False + model_tester_class = BLTModelTester + rotary_embedding_layer = BLTRotaryEmbedding # Enables RoPE tests if set # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer @@ -184,232 +91,8 @@ class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # used in `test_torch_compile_for_training` _torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None - def setUp(self): - self.model_tester = BLTModelTester(self) - self.config_tester = ConfigTester(self, config_class=BLTConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_blt_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = BLTForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_blt_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = BLTForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_blt_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = BLTForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_blt_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = BLTForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = BLTModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = BLTModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = BLTRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - - # Sanity check Yarn RoPE scaling - # Scaling should be over the entire input - config.rope_scaling = {"type": "yarn", "factor": scaling_factor} - yarn_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_short, original_cos_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - def test_model_loading_old_rope_configs(self): - def _reinitialize_config(base_config, new_kwargs): - # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation - # steps. - base_config_dict = base_config.to_dict() - new_config = BLTConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) - return new_config - - # from untouched config -> ✅ - base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() - original_model = BLTForCausalLM(base_config).to(torch_device) - original_model(**model_inputs) - - # from a config with the expected rope configuration -> ✅ - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) - original_model = BLTForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC - config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) - original_model = BLTForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) - config = _reinitialize_config( - base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} - ) - self.assertTrue(config.rope_scaling["type"] == "linear") - self.assertTrue(config.rope_scaling["rope_type"] == "linear") - original_model = BLTForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) - original_model = BLTForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("factor field", logs.output[0]) - - # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config( - base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} - ) - original_model = BLTForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("Unrecognized keys", logs.output[0]) - - # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception - with self.assertRaises(KeyError): - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - - -@require_torch_accelerator +# @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): def tearDown(self): # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves @@ -420,433 +103,54 @@ def tearDown(self): @slow @require_read_token - def test_blt_3_1_hard(self): - """ - An integration test for blt 3.1. It tests against a long output to ensure the subtle numerical differences - from blt 3.1.'s RoPE can be detected - """ - # diff on `EXPECTED_TEXT`: - # 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results. - EXPECTED_TEXT = ( - "Tell me about the french revolution. The french revolution was a period of radical political and social " - "upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked " - "by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the " - "First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative " - "assembly that had not met since 1614. The Third Estate, which represented the common people, " - "demanded greater representation and eventually broke away to form the National Assembly. This marked " - "the beginning of the end of the absolute monarchy and the rise of the middle class.\n" - ) - - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16 - ) - input_text = ["Tell me about the french revolution."] - model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) + def test_blt(self): + prompt = "my name is" - generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(generated_text, EXPECTED_TEXT) - - @slow - @require_read_token - def test_model_7b_logits_bf16(self): - input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + EXPECTED_TEXT = " alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" model = BLTForCausalLM.from_pretrained( - "meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" + "itazap/blt-1b", device_map="auto" #, torch_dtype=torch.bfloat16 ) - with torch.no_grad(): - out = model(torch.tensor([input_ids]).to(torch_device)) - # Expected mean on dim = -1 + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") - # fmt: off - expected_means = Expectations( - { - ("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]), - ("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]), - ("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]) - }) + input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(torch_device) - expected_mean = expected_means.get_expectation() - self.assertTrue( - torch.allclose( - expected_mean.to(torch_device), - out.logits.float().mean(-1), - atol=1e-2, - rtol=1e-2 - ) + output_ids = model.generate( + input_ids, + max_new_tokens=200 ) - # slicing logits[0, 0, 0:15] - expected_slices = Expectations( - { - ("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]), - ("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]), - ("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]) - }) - # fmt: on - expected_slice = expected_slices.get_expectation() - self.assertTrue( - torch.allclose( - expected_slice.to(torch_device), - out.logits[0, 0, :15].float(), - atol=1e-2, - rtol=1e-2, - ) - ) + generated_ids = output_ids[0][len(input_ids[0]):] + output_text = tokenizer.decode(generated_ids.tolist()) + + print(f'Prompt: "{prompt}"') + print(f'Completion: "{output_text}"') + print('here') + self.assertEqual(output_text, EXPECTED_TEXT) @slow - @require_torch_accelerator @require_read_token - def test_compile_static_cache(self): - # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 - # work as intended. See https://github.com/pytorch/pytorch/issues/121943 - if version.parse(torch.__version__) < version.parse("2.3.0"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - NUM_TOKENS_TO_GENERATE = 40 - # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test - # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. - EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " - "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " - "theory of relativ", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " - "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ] + def test_model_logits(self): + input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - prompts = [ - "Simply put, the theory of relativity states that ", - "My favorite all time favorite condiment is ketchup.", - ] - tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf", pad_token="", padding_side="right") model = BLTForCausalLM.from_pretrained( - "meta-blt/BLT-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16 - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - - # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) - dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - - # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - - @slow - @require_read_token - def test_export_static_cache(self): - if version.parse(torch.__version__) < version.parse("2.4.0"): - self.skipTest(reason="This test requires torch >= 2.4 to run.") - - from transformers.integrations.executorch import ( - TorchExportableModuleWithStaticCache, - convert_and_export_with_cache, - ) - - blt_models = { - "meta-blt/BLT-3.2-1B": [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all " - "observers, regardless of their location, and 2) the laws of physics are the same for all observers" - ], - } - - for blt_model_ckp, EXPECTED_TEXT_COMPLETION in blt_models.items(): - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(blt_model_ckp, pad_token="", padding_side="right") - max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ - "input_ids" - ].shape[-1] - - # Load model - device = "cpu" - dtype = torch.bfloat16 - cache_implementation = "static" - attn_implementation = "sdpa" - batch_size = 1 - model = BLTForCausalLM.from_pretrained( - blt_model_ckp, - device_map=device, - torch_dtype=dtype, - attn_implementation=attn_implementation, - generation_config=GenerationConfig( - use_cache=True, - cache_implementation=cache_implementation, - max_length=max_generation_length, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_generation_length, - "device": device, - }, - ), - ) - - prompts = ["Simply put, the theory of relativity states that "] - prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - prompt_token_ids = prompt_tokens["input_ids"] - max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] - - # Static Cache + export - exported_program = convert_and_export_with_cache(model) - ep_generated_ids = TorchExportableModuleWithStaticCache.generate( - exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens - ) - ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) - - -@slow -@require_torch_accelerator -class Mask4DTestHard(unittest.TestCase): - def tearDown(self): - cleanup(torch_device, gc_collect=True) - - def setUp(self): - cleanup(torch_device, gc_collect=True) - model_name = "TinyBLT/TinyBLT-1.1B-Chat-v1.0" - self.model_dtype = torch.float32 - self.tokenizer = BLTTokenizer.from_pretrained(model_name) - self.model = BLTForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device) - - def get_test_data(self): - template = "my favorite {}" - items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item - - batch_separate = [template.format(x) for x in items] # 3 separate lines - batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated - - input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device) - input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device) - - mask_shared_prefix = torch.tensor( - [ - [ - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] - ] - ], - device=torch_device, + "itazap/blt-1b", device_map="auto" ) - position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device) - - # building custom positions ids based on custom mask - position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1) - # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device) - - # inverting the mask - min_dtype = torch.finfo(self.model_dtype).min - mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype - - return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix - - def test_stacked_causal_mask(self): - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # 2 forward runs with custom 4D masks - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a) - past_key_values_a = outs_1a["past_key_values"] - - # Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len]) - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - outs_1b = self.model.forward( - input_1b, - attention_mask=mask_1b, - position_ids=position_ids_1b, - past_key_values=past_key_values_a, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) - - def test_stacked_causal_mask_static_cache(self): - """same as above but with StaticCache""" - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - padded_attention_mask = torch.nn.functional.pad( - input=mask_shared_prefix, - pad=(0, max_cache_len - mask_shared_prefix.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - # single forward run with 4D custom mask - logits_shared_prefix = self.model.forward( - input_ids_shared_prefix, - attention_mask=padded_attention_mask, - position_ids=position_ids_shared_prefix, - cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device), - past_key_values=past_key_values, - ).logits - logits_shared_prefix_last = logits_shared_prefix[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], : - ] # last three tokens - decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)] - - self.assertEqual(decoded, decoded_shared_prefix) - - def test_partial_stacked_causal_mask_static_cache(self): - # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks - # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]) - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self.get_test_data() - - # regular batch - logits = self.model.forward(input_ids, position_ids=position_ids).logits - logits_last = logits[:, -1, :] # last tokens in each batch line - decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)] - - # upgrade the model with StaticCache - max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) - - # forward run for the first part of input - part_a = 3 # split point - - input_1a = input_ids_shared_prefix[:, :part_a] - position_ids_1a = position_ids_shared_prefix[:, :part_a] - mask_1a = mask_shared_prefix[:, :, :part_a, :part_a] - - padded_mask_1a = torch.nn.functional.pad( - input=mask_1a, - pad=(0, max_cache_len - mask_1a.shape[-1]), - mode="constant", - value=torch.finfo(self.model_dtype).min, - ) - - _ = self.model.forward( - input_1a, - attention_mask=padded_mask_1a, - position_ids=position_ids_1a, - cache_position=torch.arange(part_a, device=torch_device), - past_key_values=past_key_values, - ) - - # forward run for the second part of input - input_1b = input_ids_shared_prefix[:, part_a:] - position_ids_1b = position_ids_shared_prefix[:, part_a:] - mask_1b = mask_shared_prefix[:, :, part_a:, :] - - padded_mask_1b = torch.nn.functional.pad( - input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0 - ) - - outs_1b = self.model.forward( - input_1b, - attention_mask=padded_mask_1b, - position_ids=position_ids_1b, - cache_position=torch.arange( - part_a, - input_ids_shared_prefix.shape[-1], - device=torch_device, - ), - past_key_values=past_key_values, - ) - decoded_1b = [ - self.tokenizer.decode(t) - for t in outs_1b.logits.argmax(-1)[ - 0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a - ] - ] - self.assertEqual(decoded, decoded_1b) + with torch.no_grad(): + output = model(torch.tensor([input_ids]).to(torch_device))[0] + + EXPECTED_OUTPUT = torch.tensor([[-10.4948, -10.7065, -6.1813, -10.5545, -10.3428, -9.1493, -8.4937, + -8.6382, -9.2159, -9.5907, -9.3679, -8.4184, -9.0655, -3.4436, + 2.9616, -10.3157, -6.3723, -6.0133, -9.7100, -9.2128, -8.8064, + -9.8179, -9.7516, -9.4681, -9.7715, -9.4897, -9.0491, -9.8098, + -9.4648, -9.3294], + [-13.3010, -13.1910, -5.7230, -13.2895, -13.4864, -8.7140, -7.0275, + -7.0182, -10.1362, -10.3762, -9.9086, -7.8049, -8.8660, -5.2711, + -3.5778, -12.5346, -9.1609, -6.7925, -10.3717, -9.2650, -10.6393, + -11.4807, -11.2128, -10.9615, -10.5806, -10.8873, -11.0651, -11.3471, + -10.5437, -9.9688]]).to(torch_device) + + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4) From c32f692aa15fc40b6a61fe0ea3f7129ee76ae770 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 7 Jul 2025 13:08:50 +0000 Subject: [PATCH 060/139] added integration tests --- src/transformers/models/auto/modeling_auto.py | 1 + .../models/blt/configuration_blt.py | 31 +++- src/transformers/models/blt/modeling_blt.py | 14 +- .../models/blt/tokenization_blt.py | 49 ++++-- tests/models/blt/test_modeling_blt.py | 162 ++++++++++++++---- 5 files changed, 186 insertions(+), 71 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index abc4557ecfe4..961ea5ddd91c 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -72,6 +72,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blip-2", "Blip2Model"), ("blip_2_qformer", "Blip2QFormerModel"), ("bloom", "BloomModel"), + ("blt", "BLTModel"), ("bridgetower", "BridgeTowerModel"), ("bros", "BrosModel"), ("camembert", "CamembertModel"), diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index a9a245afebc3..50788081813f 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -66,9 +66,10 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - self._attn_implementation = _attn_implementation super().__init__(**kwargs) + + self._attn_implementation = _attn_implementation class BLTLocalDecoderConfig(PretrainedConfig): """ @@ -117,6 +118,9 @@ def __init__( super().__init__(**kwargs) + self._attn_implementation = _attn_implementation + + class BLTGlobalTransformerConfig(PretrainedConfig): """ @@ -153,10 +157,12 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - self._attn_implementation = _attn_implementation super().__init__(**kwargs) + self._attn_implementation = _attn_implementation + + class BLTPatcherConfig(PretrainedConfig): r""" @@ -187,7 +193,7 @@ class BLTPatcherConfig(PretrainedConfig): Make feedforward dimension multiple of this for the entropy model. rope_theta (`float`, *optional*, defaults to 10000.0): RoPE theta parameter for the entropy model. - attn_impl (`str`, *optional*, defaults to "sdpa"): + _attn_implementation (`str`, *optional*, defaults to "sdpa"): Attention implementation for the entropy model. attn_bias_type (`str`, *optional*, defaults to "causal"): Attention bias type for the entropy model. @@ -206,7 +212,7 @@ def __init__( norm_eps=1e-5, dropout=0.0, rope_theta=10000.0, - attn_impl="sdpa", + _attn_implementation="sdpa", attn_bias_type="causal", intermediate_size=None, **kwargs, @@ -221,13 +227,15 @@ def __init__( self.norm_eps = norm_eps self.dropout = dropout self.rope_theta = rope_theta - self.attn_impl = attn_impl self.attn_bias_type = attn_bias_type self.hidden_act = "silu" # BLT uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) self.rope_scaling = {"rope_type": "default"} super().__init__(**kwargs) + self._attn_implementation = _attn_implementation + + class BLTConfig(PretrainedConfig): r""" @@ -242,6 +250,9 @@ class BLTConfig(PretrainedConfig): Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. + _attn_implementation (`str`, *optional*, defaults to "sdpa"): + The attention implementation to use. Can be "eager", "sdpa", etc. This setting is propagated to all + sub-components (encoder, decoder, global transformer, patcher). # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): @@ -322,6 +333,7 @@ def __init__( decoder_config=None, global_config=None, tie_word_embeddings=False, + _attn_implementation="sdpa", **kwargs, ): @@ -329,6 +341,7 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings + self._attn_implementation = _attn_implementation # Patching configuration self.patch_in_forward = patch_in_forward @@ -348,7 +361,7 @@ def __init__( # Initialize component configurations if patcher_config is None: - self.patcher_config = BLTPatcherConfig() + self.patcher_config = BLTPatcherConfig(_attn_implementation=_attn_implementation) logger.info("patcher_config is None, using default BLT patcher config") elif isinstance(patcher_config, dict): self.patcher_config = BLTPatcherConfig(**patcher_config) @@ -356,7 +369,7 @@ def __init__( self.patcher_config = patcher_config if encoder_config is None: - self.encoder_config = BLTLocalEncoderConfig() + self.encoder_config = BLTLocalEncoderConfig(_attn_implementation=_attn_implementation) logger.info("encoder_config is None, using default BLT encoder config") elif isinstance(encoder_config, dict): self.encoder_config = BLTLocalEncoderConfig(**encoder_config) @@ -364,7 +377,7 @@ def __init__( self.encoder_config = encoder_config if decoder_config is None: - self.decoder_config = BLTLocalDecoderConfig() + self.decoder_config = BLTLocalDecoderConfig(_attn_implementation=_attn_implementation) logger.info("decoder_config is None, using default BLT decoder config") elif isinstance(decoder_config, dict): self.decoder_config = BLTLocalDecoderConfig(**decoder_config) @@ -372,7 +385,7 @@ def __init__( self.decoder_config = decoder_config if global_config is None: - self.global_config = BLTGlobalTransformerConfig() + self.global_config = BLTGlobalTransformerConfig(_attn_implementation=_attn_implementation) logger.info("global_config is None, using default BLT global config") elif isinstance(global_config, dict): self.global_config = BLTGlobalTransformerConfig(**global_config) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 84d874daa76e..d1170a99b8fb 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -250,7 +250,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None + self.scaling = self.head_dim ** -0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -291,7 +291,6 @@ def forward( attention_interface: Callable = eager_attention_forward output_attentions = False - self.config._attn_implementation = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -732,7 +731,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None #self.head_dim ** -0.5 + self.scaling = self.head_dim ** -0.5 self.dropout = config.dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -784,7 +783,6 @@ def forward( attention_interface: Callable = eager_attention_forward - self.config._attn_implementation = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -978,15 +976,15 @@ def forward( self.config.max_patch_length ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 - ) encoder_embeds = compute_hash_embeddings( tokens, self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype + ) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=tokens, input_embeds=encoder_embeds, @@ -1002,7 +1000,7 @@ def forward( ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, encoder_embeds.dtype ) output, _ = self.local_decoder( tokens=tokens, diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 6973a4f52631..b699bf50d4c3 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -79,18 +79,16 @@ def __init__( unk_token="", boe_token="", add_bos_token=True, - add_eos_token=True, + add_eos_token=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False, **kwargs, ): - # Store BLT-specific parameters first self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.vocab_size_unit_1 = BYTE_UNITS self.offsetting_special_char = OFFSET - # BLT token IDs (exactly like original) self.boe_id = BOE_ID self.bos_id = BOS_ID self.eos_id = EOS_ID @@ -98,7 +96,6 @@ def __init__( self.bpe_id = BPE_ID self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char - # Convert string tokens to AddedToken objects bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token @@ -175,14 +172,13 @@ def _convert_token_to_id(self, token: str) -> int: return self.boe_id else: try: - # Convert byte value string to int and add offset (like original) + # Convert byte value string to int and add offset byte_val = int(token) if 0 <= byte_val <= 255: return byte_val + self.offsetting_special_char except ValueError: pass - # Check if it's in added tokens return self.added_tokens_encoder.get(token, self.unk_token_id) def _convert_id_to_token(self, index: int) -> str: @@ -197,7 +193,7 @@ def _convert_id_to_token(self, index: int) -> str: elif index == self.boe_id: return str(self.boe_token) elif index >= self.offsetting_special_char and index < self.vocab_size: - # Convert back to byte value (like original) + # Convert back to byte value byte_val = index - self.offsetting_special_char return str(byte_val) else: @@ -217,31 +213,27 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: continue try: - # Convert token back to byte value (like original decode method) + # Convert token back to byte value byte_val = int(token) if 0 <= byte_val <= 255: byte_values.append(byte_val) except ValueError: continue - # Convert byte values back to string (exactly like original) + # Convert byte values back to string try: return bytes(byte_values).decode("utf-8", errors="ignore") except (UnicodeDecodeError, ValueError): return "" def encode(self, text: str, add_special_tokens: bool = True, **kwargs): - """ - Encode text exactly like the original BLT tokenizer. - """ - # Handle the standard PreTrainedTokenizer interface add_bos = kwargs.get('add_bos', self.add_bos_token if add_special_tokens else False) add_eos = kwargs.get('add_eos', self.add_eos_token if add_special_tokens else False) # Since bpe_delim=False, we use the simple byte encoding tokens = bytes(text, encoding="utf-8", errors="ignore") - # Offsetting (exactly like original) + # Offsetting tokens = [int(unit) + self.offsetting_special_char for unit in tokens] if add_bos: @@ -251,10 +243,7 @@ def encode(self, text: str, add_special_tokens: bool = True, **kwargs): return tokens - def decode(self, tokens: list[int], cut_at_eos: bool = False): - """ - Decode tokens exactly like the original BLT tokenizer. - """ + def decode(self, tokens, cut_at_eos: bool = False): if cut_at_eos: for k, t in enumerate(tokens): if t == self.eos_id: @@ -264,6 +253,30 @@ def decode(self, tokens: list[int], cut_at_eos: bool = False): [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] ).decode("utf-8", errors="ignore") + def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and + adding special tokens. A BLT sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + bos = [self.bos_id] if self.add_bos_token else [] + eos = [self.eos_id] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos + token_ids_0 + eos + return bos + token_ids_0 + eos + token_ids_1 + eos + def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" return self.vocab_size_unit_1 + self.offsetting_special_char diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index deb14b569322..f6e240f8d79d 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -18,13 +18,12 @@ from packaging import version from transformers import AutoTokenizer, StaticCache, is_torch_available -from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( - Expectations, cleanup, require_read_token, require_torch, require_torch_accelerator, + require_torch_bf16, slow, torch_device, ) @@ -43,16 +42,6 @@ ) from transformers.models.blt.modeling_blt import BLTRotaryEmbedding -# import os -# import gc -# gc.collect() -# torch.cuda.empty_cache() - - -# os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" -# os.environ["TORCH_USE_CUDA_DSA"] = "1" - - class BLTModelTester(CausalLMModelTester): if is_torch_available(): config_class = BLTConfig @@ -92,7 +81,7 @@ class BLTModelTest(CausalLMModelTest, unittest.TestCase): _torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None -# @require_torch_accelerator +@require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): def tearDown(self): # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves @@ -104,43 +93,30 @@ def tearDown(self): @slow @require_read_token def test_blt(self): - prompt = "my name is" + NUM_TOKENS_TO_GENERATE = 200 + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" - EXPECTED_TEXT = " alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" + prompt = "my name is" model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", device_map="auto" #, torch_dtype=torch.bfloat16 + "itazap/blt-1b", + device_map="auto", ) tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") - input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(torch_device) + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - output_ids = model.generate( - input_ids, - max_new_tokens=200 - ) - - generated_ids = output_ids[0][len(input_ids[0]):] - output_text = tokenizer.decode(generated_ids.tolist()) - - print(f'Prompt: "{prompt}"') - print(f'Completion: "{output_text}"') - print('here') + old_input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(torch_device) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) + @slow @require_read_token def test_model_logits(self): - input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", device_map="auto" - ) - - with torch.no_grad(): - output = model(torch.tensor([input_ids]).to(torch_device))[0] EXPECTED_OUTPUT = torch.tensor([[-10.4948, -10.7065, -6.1813, -10.5545, -10.3428, -9.1493, -8.4937, -8.6382, -9.2159, -9.5907, -9.3679, -8.4184, -9.0655, -3.4436, @@ -153,4 +129,118 @@ def test_model_logits(self): -11.4807, -11.2128, -10.9615, -10.5806, -10.8873, -11.0651, -11.3471, -10.5437, -9.9688]]).to(torch_device) + input_ids = [1, 42, 21, 12, 43, 23, 1, 4] + + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b", device_map="auto" + ) + + with torch.no_grad(): + output = model(torch.tensor([input_ids]).to(torch_device))[0] + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4) + + @slow + @require_read_token + @require_torch_bf16 + def test_model_bf16(self): + """Test BLT model with bfloat16 precision.""" + NUM_TOKENS_TO_GENERATE = 200 + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + + prompt = "my name is" + + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b", + device_map="auto", + torch_dtype=torch.bfloat16 + ) + + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + + output_text = tokenizer.decode(generated_ids[0]) + self.assertEqual(output_text, EXPECTED_TEXT) + + @slow + @require_read_token + @require_torch_bf16 + def test_model_logits_bf16(self): + """Test BLT model logits with bfloat16 precision.""" + EXPECTED_OUTPUT = torch.tensor([[-10.5000, -10.7500, -6.2188, -10.5625, -10.3750, -9.1875, -8.5000, + -8.6250, -9.1875, -9.6250, -9.3750, -8.5000, -9.0625, -3.4219, + 2.9688, -10.3125, -6.4062, -6.0000, -9.6875, -9.2500, -8.8125, + -9.8750, -9.7500, -9.5000, -9.8125, -9.5000, -9.0625, -9.8750, + -9.5000, -9.3750], + [-13.3750, -13.2500, -5.5938, -13.3750, -13.5000, -8.7500, -7.0312, + -7.0000, -10.1875, -10.3750, -9.8750, -7.8125, -8.8750, -5.3125, + -3.5469, -12.5625, -9.1875, -6.7812, -10.3750, -9.2500, -10.6250, + -11.5000, -11.2500, -11.0000, -10.6250, -10.9375, -11.1250, -11.3750, + -10.5625, -10.0000]], dtype=torch.bfloat16).to(torch_device) + + input_ids = [1, 42, 21, 12, 43, 23, 1, 4] + + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b", + device_map="auto", + torch_dtype=torch.bfloat16 + ) + + with torch.no_grad(): + output = model(torch.tensor([input_ids]).to(torch_device))[0] + + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) + + @slow + @require_read_token + def test_model_eager(self): + """Test BLT model with bfloat16 precision using eager attention implementation.""" + NUM_TOKENS_TO_GENERATE = 200 + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" + + prompt = "my name is" + + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b", + device_map="auto", + attn_implementation="eager" + ) + + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + + output_text = tokenizer.decode(generated_ids[0]) + self.assertEqual(output_text, EXPECTED_TEXT) + + @slow + @require_read_token + @require_torch_bf16 + def test_model_bf16_static_cache(self): + """Test BLT model with bfloat16 precision and static cache.""" + NUM_TOKENS_TO_GENERATE = 200 + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + + prompt = "my name is" + + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b", + device_map="auto", + torch_dtype=torch.bfloat16 + ) + + model.generation_config.cache_implementation = "static" + + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + + output_text = tokenizer.decode(generated_ids[0]) + self.assertEqual(output_text, EXPECTED_TEXT) From e72383031bb639f13078b63647277c20212d00b4 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 8 Jul 2025 09:51:24 +0000 Subject: [PATCH 061/139] adding tokenization tests, integration tests, and cleaned up tokenization file, + ruff --- .../models/blt/configuration_blt.py | 53 +- .../models/blt/convert_blt_weights_to_hf.py | 50 +- src/transformers/models/blt/modeling_blt.py | 259 ++-- .../models/blt/tokenization_blt.py | 227 ++-- tests/models/blt/test_modeling_blt.py | 205 +++- tests/models/blt/test_tokenization_blt.py | 1080 ++++------------- 6 files changed, 656 insertions(+), 1218 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 50788081813f..15613e307dc5 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -23,15 +23,16 @@ logger = logging.get_logger(__name__) + class BLTLocalEncoderConfig(PretrainedConfig): """ Configuration class for the BLT Local Encoder component. """ - + model_type = "blt_local_encoder" - + def __init__( - self, + self, vocab_size=256, cross_attn_all_layers=True, cross_attn_k=2, @@ -66,18 +67,19 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - + super().__init__(**kwargs) self._attn_implementation = _attn_implementation - + + class BLTLocalDecoderConfig(PretrainedConfig): """ Configuration class for the BLT Local Decoder component. """ - + model_type = "blt_local_decoder" - + def __init__( self, vocab_size=256, @@ -121,14 +123,13 @@ def __init__( self._attn_implementation = _attn_implementation - class BLTGlobalTransformerConfig(PretrainedConfig): """ Configuration class for the BLT Global Transformer component. """ - + model_type = "blt_global_transformer" - + def __init__( self, hidden_size=512, @@ -157,17 +158,16 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"rope_type": "default"} self.hidden_act = hidden_act - + super().__init__(**kwargs) self._attn_implementation = _attn_implementation - class BLTPatcherConfig(PretrainedConfig): r""" Configuration class for the BLT Patcher/Entropy model component. - + Args: vocab_size (`int`, *optional*, defaults to 256): Vocabulary size for the entropy model used in patching. @@ -198,9 +198,9 @@ class BLTPatcherConfig(PretrainedConfig): attn_bias_type (`str`, *optional*, defaults to "causal"): Attention bias type for the entropy model. """ - + model_type = "blt_patcher" - + def __init__( self, vocab_size=256, @@ -236,7 +236,6 @@ def __init__( self._attn_implementation = _attn_implementation - class BLTConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a @@ -251,7 +250,7 @@ class BLTConfig(PretrainedConfig): max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. _attn_implementation (`str`, *optional*, defaults to "sdpa"): - The attention implementation to use. Can be "eager", "sdpa", etc. This setting is propagated to all + The attention implementation to use. Can be "eager", "sdpa", etc. This setting is propagated to all sub-components (encoder, decoder, global transformer, patcher). # Patching configuration @@ -308,10 +307,10 @@ class BLTConfig(PretrainedConfig): model_type = "blt" keys_to_ignore_at_inference = ["past_key_values"] sub_configs = { - "patcher_config": BLTPatcherConfig, - "encoder_config": BLTLocalEncoderConfig, - "decoder_config": BLTLocalDecoderConfig, - "global_config": BLTGlobalTransformerConfig + "patcher_config": BLTPatcherConfig, + "encoder_config": BLTLocalEncoderConfig, + "decoder_config": BLTLocalDecoderConfig, + "global_config": BLTGlobalTransformerConfig, } def __init__( @@ -336,7 +335,6 @@ def __init__( _attn_implementation="sdpa", **kwargs, ): - # Basic model configuration self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size @@ -394,10 +392,11 @@ def __init__( super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + __all__ = [ - "BLTConfig", - "BLTPatcherConfig", - "BLTLocalEncoderConfig", - "BLTLocalDecoderConfig", - "BLTGlobalTransformerConfig", + "BLTConfig", + "BLTPatcherConfig", + "BLTLocalEncoderConfig", + "BLTLocalDecoderConfig", + "BLTGlobalTransformerConfig", ] diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index 26c05477a169..d025e09cbc31 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -43,8 +43,10 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str # Create patcher config patcher_hidden_size = int(entropy_model_params.get("dim", 512)) patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256)) - patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of) - + patcher_intermediate_size = patcher_multiple_of * ( + (int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of + ) + patcher_config = { "vocab_size": int(entropy_model_params.get("vocab_size", 256)), "hidden_size": patcher_hidden_size, @@ -65,8 +67,10 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str # Create encoder config encoder_hidden_size = unified_config.get("dim_local_encoder", 1024) encoder_multiple_of = unified_config.get("multiple_of", 256) - encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of) - + encoder_intermediate_size = encoder_multiple_of * ( + (int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of + ) + encoder_config = { "vocab_size": unified_config.get("vocab_size", 256), "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), @@ -79,7 +83,8 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1), "norm_eps": unified_config.get("norm_eps", 1e-5), "dropout": unified_config.get("dropout", 0.0), - "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") + or unified_config.get("max_seqlen", 1024), "rope_theta": unified_config.get("rope_theta", 10000.0), "rope_scaling": {"rope_type": "default"}, "hidden_act": unified_config.get("hidden_act", "silu"), @@ -90,8 +95,10 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str # Create decoder config decoder_hidden_size = unified_config.get("dim_local_decoder", 1024) decoder_multiple_of = unified_config.get("multiple_of", 256) - decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of) - + decoder_intermediate_size = decoder_multiple_of * ( + (int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of + ) + decoder_config = { "vocab_size": unified_config.get("vocab_size", 256), "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), @@ -103,7 +110,8 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9), "norm_eps": unified_config.get("norm_eps", 1e-5), "dropout": unified_config.get("dropout", 0.0), - "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") + or unified_config.get("max_seqlen", 1024), "rope_theta": unified_config.get("rope_theta", 10000.0), "rope_scaling": {"rope_type": "default"}, "hidden_act": unified_config.get("hidden_act", "silu"), @@ -114,8 +122,10 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str # Create global transformer config global_hidden_size = unified_config.get("dim_global", 2048) global_multiple_of = unified_config.get("multiple_of", 256) - global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of) - + global_intermediate_size = global_multiple_of * ( + (int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of + ) + global_config = { "hidden_size": global_hidden_size, "num_attention_heads": unified_config.get("n_heads_global", 16), @@ -163,7 +173,7 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str return main_config_dict -def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: component_mappings = { ".attention.": ".self_attn.", ".feed_forward.": ".mlp.", @@ -181,18 +191,18 @@ def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch ".wo.": ".o_proj.", ".output.": ".lm_head.", } - + new_state_dict = {} - + for old_key, tensor in state_dict.items(): new_key = old_key - + for old_pattern, new_pattern in component_mappings.items(): if old_pattern in new_key: new_key = new_key.replace(old_pattern, new_pattern) - + new_state_dict[new_key] = tensor - + return new_state_dict @@ -211,9 +221,9 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor for key, tensor in entropy_weights.items(): patcher_key = f"patcher.{key}" unified_weights[patcher_key] = tensor - + unified_weights = apply_weight_mapping(unified_weights) - + decoder_lm_head_key = "local_decoder.lm_head.weight" top_lm_head_key = "lm_head.weight" unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key] @@ -227,9 +237,9 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor prefixed_weights[f"model.{key}"] = tensor else: prefixed_weights[key] = tensor - + unified_weights = prefixed_weights - + return unified_weights diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index d1170a99b8fb..e089d2cf97a8 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -49,6 +49,7 @@ logger = logging.get_logger(__name__) + class PatchingModeEnum(str, Enum): entropy = "entropy" bpe = "bpe" @@ -71,14 +72,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -101,31 +102,31 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # Extract first head_dim//2 elements which correspond to the unique frequencies # This matches the original BLT approach which uses head_dim//2 frequency pairs head_dim = q.shape[-1] - cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] - sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] - + cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] + sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - + q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + # Extract real and i parts q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] q_real_rot = cos_freqs * q_real - sin_freqs * q_imag q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag k_real_rot = cos_freqs * k_real - sin_freqs * k_imag k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - + # Recombine pairs and reshape back to original format q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - + return q_rot.type_as(q), k_rot.type_as(k) @@ -142,7 +143,7 @@ def __init__(self, config): def forward(self, x: torch.Tensor) -> torch.Tensor: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj - + class BLTRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -250,7 +251,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -260,30 +261,30 @@ def __init__(self, config, layer_idx: int): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -318,32 +319,43 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - + def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ - 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, - 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, ] prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime ** powers + prime_powers = prime**powers return torch.sum(token_tensor * prime_powers, dim=-1) -def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): +def byte_group_hash_function( + token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): batch_size, seq_len = token_ids.shape # Add padding for sliding window padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) padded_tokens = torch.cat([padding, token_ids], dim=1) - + # Create sliding windows and compute hashes windows = padded_tokens.unfold(1, group_size, 1) hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values = hashes % max_hash - + hash_values.requires_grad = False return hash_values @@ -351,10 +363,7 @@ def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_ def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): """Initialize hash-based token embeddings for the BLT encoder.""" num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [ - nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) - for _ in range(num_embeddings) - ] + embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] return nn.ModuleList(embeddings) @@ -390,10 +399,10 @@ def _prepare_patch_cross_attention_mask( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare cross-attention mask for patch-based attention, following mllama's robust approach. - + This function creates masks that control which patches can attend to which other patches, with support for query/key role swapping and cross-attention multipliers. - + Args: patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. num_patches (int): Total number of patches. @@ -401,22 +410,25 @@ def _prepare_patch_cross_attention_mask( patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. cross_attn_k (int): Cross-attention multiplier for repeating patches. dtype (torch.dtype): Data type for the output mask. - + Returns: - Tuple[torch.Tensor, torch.Tensor]: - - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ batch_size, seq_len = patch_ids.shape device = patch_ids.device - + # Determine query and key lengths based on configuration if patches_as_queries: q_len = num_patches * cross_attn_k kv_len = sequence_length # Create patch-to-sequence mapping - q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( - batch_size, num_patches, seq_len + q_patch_ids = ( + torch.arange(num_patches, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, num_patches, seq_len) ) kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) else: @@ -424,33 +436,35 @@ def _prepare_patch_cross_attention_mask( kv_len = num_patches * cross_attn_k # Create sequence-to-patch mapping q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) - kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( - batch_size, seq_len, num_patches + kv_patch_ids = ( + torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) ) - + # Create base attention mask - boolean mask where True means "should attend" # Exact patch matching cross_attention_mask = q_patch_ids == kv_patch_ids - + # Handle cross_attn_k multiplier by repeating along appropriate dimension repeat_dim = 1 if patches_as_queries else -1 cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) - + # Validate dimensions expected_shape = (batch_size, q_len, kv_len) if cross_attention_mask.shape != expected_shape: - raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") - + raise ValueError( + f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" + ) + # Reshape so it can be used by attn module - add head dimension cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] - + # Invert the mask (following mllama pattern exactly) # True -> 0.0 (attend), False -> 1.0 (will become -inf) - inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) cross_attention_mask = inverted_cross_attn_mask.masked_fill( inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min ) - + # Apply full-row bias (following mllama pattern exactly) # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's # last dimension contains negative infinity values, otherwise it's 1 @@ -459,7 +473,7 @@ def _prepare_patch_cross_attention_mask( (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] ) cross_attention_mask *= full_text_row_masked_out_mask - + return cross_attention_mask, full_text_row_masked_out_mask @@ -497,7 +511,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona for i, splits in enumerate(processed): if splits: - padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) # Trim zero columns if (padded != 0).any(dim=0).sum() < padded.shape[1]: @@ -539,10 +553,12 @@ def forward(self, x, position_ids): class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - + self.config = config - - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.layers = nn.ModuleList( + [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -579,10 +595,10 @@ def forward( batch_size, _, _ = input_embeds.shape - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -593,7 +609,9 @@ def forward( if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) layer_idx = idx if self.config.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( @@ -609,7 +627,7 @@ def forward( encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states - + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch @@ -625,7 +643,9 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) - reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = torch.zeros( + (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) reduced_embeddings = reduced_embeddings.scatter_reduce( src=hidden_states, dim=1, @@ -644,9 +664,11 @@ def __init__(self, config: BLTLocalDecoderConfig): # Extract config values to instance attributes self.config = config - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -671,7 +693,6 @@ def __init__(self, config: BLTLocalDecoderConfig): # bias=False, # ) - def forward( self, tokens: torch.Tensor, @@ -687,13 +708,15 @@ def forward( hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): @@ -714,7 +737,7 @@ def forward( hidden_states = layer_outputs[0] logits = self.norm(hidden_states) - # logits = self.lm_head(logits) + # logits = self.lm_head(logits) return logits, cache @@ -731,7 +754,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.dropout = config.dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -756,8 +779,8 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first + + query_states = self.q_norm(hidden_states) # BLT normalizes first query_states = self.q_proj(query_states) if cross_attention_states is not None: @@ -833,7 +856,6 @@ def __init__(self, config: BLTGlobalTransformerConfig): self.rotary_emb = BLTRotaryEmbedding(config=config) - def forward( self, input_embeds: torch.Tensor, @@ -847,7 +869,7 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) @@ -856,8 +878,6 @@ def forward( return hidden_states, cache - - class BLTPreTrainedModel(PreTrainedModel): config_class = BLTConfig base_model_prefix = "model" @@ -870,7 +890,7 @@ class BLTPreTrainedModel(PreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Linear): - std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + std = getattr(module, "_custom_std", module.in_features ** (-0.5)) nn.init.trunc_normal_( module.weight, mean=0.0, @@ -880,9 +900,9 @@ def _init_weights(self, module): ) if module.bias is not None: nn.init.zeros_(module.bias) - + elif isinstance(module, nn.Embedding): - std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + std = getattr(module, "_custom_std", module.embedding_dim ** (-0.5)) nn.init.trunc_normal_( module.weight, mean=0.0, @@ -890,26 +910,26 @@ def _init_weights(self, module): a=-3 * std, b=3 * std, ) - + elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: emb_std = module.config.encoder_config.hidden_size ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - + elif isinstance(module, BLTLocalEncoder): if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - + elif isinstance(module, BLTLocalDecoder): if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - + elif isinstance(module, BLTPatcher): emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std - + elif isinstance(module, BLTForCausalLM): if module.lm_head is not None: module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) @@ -973,11 +993,13 @@ def forward( else: patch_lengths = process_patch_lengths( torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), - self.config.max_patch_length + self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) encoder_embeds = compute_hash_embeddings( - tokens, self.local_encoder, self.encoder_hash_tok_embedding, + tokens, + self.local_encoder, + self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, @@ -1000,7 +1022,12 @@ def forward( ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, encoder_embeds.dtype + decoder_patch_ids, + patch_lengths.shape[1], + sequence_length, + False, + self.config.cross_attn_k, + encoder_embeds.dtype, ) output, _ = self.local_decoder( tokens=tokens, @@ -1016,18 +1043,21 @@ def forward( else: return (output, None, None) return output - + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """Convert patch lengths to patch IDs for each token position.""" batch_size = patch_lengths.shape[0] - patch_starts = torch.cat([ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1] - ], dim=-1) - + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], + ], + dim=-1, + ) + token_positions = torch.arange(seq_len, device=patch_lengths.device) return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 - + class BLTPatcher(BLTPreTrainedModel): def __init__(self, config: BLTPatcherConfig): @@ -1036,11 +1066,10 @@ def __init__(self, config: BLTPatcherConfig): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) @@ -1060,7 +1089,6 @@ def forward( patching_batch_size: int = 1, device: Optional[str] = None, ): - # Handle chunked processing for entropy calculation entropies = [] predictions = [] @@ -1085,15 +1113,15 @@ def forward( batch_size, _, _ = input_embeds.shape position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - - position_embeddings = self.rotary_emb(hidden_states, position_ids) - + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for i, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) @@ -1114,7 +1142,9 @@ def forward( ) else: # Default to byte-level patching - patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device + ) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) return concat_entropies, patch_lengths, concat_predictions @@ -1136,7 +1166,9 @@ def patch_lengths_from_entropies( batch_size = entropies.shape[0] # Always include token 0 and 1 as starting tokens - init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) offset = init_tokens.shape[1] # Ignore first token entropy (BOS) @@ -1273,13 +1305,14 @@ def forward( attentions=None, ) + __all__ = [ "BLTPreTrainedModel", "BLTModel", "BLTPatcher", "BLTLocalEncoder", - "BLTLocalDecoder", + "BLTLocalDecoder", "BLTGlobalTransformer", "BLTTransformerLayer", "BLTForCausalLM", -] \ No newline at end of file +] diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index b699bf50d4c3..646b78a58120 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -14,10 +14,9 @@ # limitations under the License. """Tokenization classes for BLT.""" -import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple -from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging @@ -30,7 +29,7 @@ SEP = " " BOS_ID: int = 1 EOS_ID: int = 2 -PAD_ID: int = -1 +PAD_ID: int = 260 # Use valid ID after byte tokens (4-259) BOE_ID: int = 0 BPE_ID: int = 3 OFFSET: int = 4 @@ -44,7 +43,7 @@ class BLTTokenizer(PreTrainedTokenizer): Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token. This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. - It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), + It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), beginning of example (BOE), and padding (PAD). Args: @@ -74,7 +73,7 @@ class BLTTokenizer(PreTrainedTokenizer): def __init__( self, bos_token="", - eos_token="", + eos_token="", pad_token="", unk_token="", boe_token="", @@ -84,29 +83,43 @@ def __init__( spaces_between_special_tokens=False, **kwargs, ): + # Store BLT-specific parameters first self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.vocab_size_unit_1 = BYTE_UNITS self.offsetting_special_char = OFFSET - + + # BLT token IDs (exactly like original) self.boe_id = BOE_ID - self.bos_id = BOS_ID + self.bos_id = BOS_ID self.eos_id = EOS_ID self.pad_id = PAD_ID self.bpe_id = BPE_ID self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char + self.boe_token = boe_token + + # Build encoder (token -> id) and decoder (id -> token) mappings + self.encoder = {} + + # Add special tokens to encoder + self.encoder[str(bos_token)] = self.bos_id + self.encoder[str(eos_token)] = self.eos_id + self.encoder[str(pad_token)] = self.pad_id + self.encoder[str(boe_token)] = self.boe_id - bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token - unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token + # Add byte tokens (0-255) to encoder + for i in range(self.vocab_size_unit_1): + self.encoder[str(i)] = i + self.offsetting_special_char + + # Create decoder as reverse of encoder + self.decoder = {v: k for k, v in self.encoder.items()} super().__init__( bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, + boe_token=boe_token, add_bos_token=add_bos_token, add_eos_token=add_eos_token, clean_up_tokenization_spaces=clean_up_tokenization_spaces, @@ -117,143 +130,60 @@ def __init__( @property def vocab_size(self): """Returns vocab size""" - return self.vocab_size_unit_1 + self.offsetting_special_char + # Account for byte tokens (4-259) plus special tokens (0,1,2,3,260) + return max(self.vocab_size_unit_1 + self.offsetting_special_char, PAD_ID + 1) def get_vocab(self): """Returns vocab as a dict""" - # Create a mapping for byte values + offset - vocab = {} - - # Add special tokens (with defensive checks) - if hasattr(self, 'bos_token'): - vocab[str(self.bos_token)] = self.bos_id - if hasattr(self, 'eos_token'): - vocab[str(self.eos_token)] = self.eos_id - if hasattr(self, 'pad_token'): - vocab[str(self.pad_token)] = self.pad_id - if hasattr(self, 'boe_token'): - vocab[str(self.boe_token)] = self.boe_id - - # Add byte tokens as string representations of byte values - vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS) - offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET) - for i in range(vocab_size_unit_1): - vocab[str(i)] = i + offsetting_special_char - - # Add any additional tokens if available - if hasattr(self, 'added_tokens_encoder'): - vocab.update(self.added_tokens_encoder) - return vocab - - def _tokenize(self, text: str, **kwargs) -> List[str]: - """ - Converts a string to a list of tokens. For BLT, we work directly with byte values. - Returns a list of strings that represent the byte values. - """ - # Convert text to UTF-8 bytes, just like the original - try: - bytes_data = text.encode("utf-8", errors="ignore") - except UnicodeEncodeError: - bytes_data = text.encode("utf-8", errors="ignore") - - # Return string representations of byte values for the tokenizer framework - return [str(byte_val) for byte_val in bytes_data] + return dict(self.encoder, **self.added_tokens_encoder) def _convert_token_to_id(self, token: str) -> int: """Converts a token (str) to an id using the vocab.""" - # Handle special tokens - if token == str(self.bos_token): - return self.bos_id - elif token == str(self.eos_token): - return self.eos_id - elif token == str(self.pad_token): - return self.pad_id - elif token == str(self.boe_token): - return self.boe_id - else: - try: - # Convert byte value string to int and add offset - byte_val = int(token) - if 0 <= byte_val <= 255: - return byte_val + self.offsetting_special_char - except ValueError: - pass - - return self.added_tokens_encoder.get(token, self.unk_token_id) + return self.encoder.get(token, self.added_tokens_encoder.get(token, self.unk_token_id)) def _convert_id_to_token(self, index: int) -> str: """Converts an index (integer) to a token (str) using the vocab.""" - # Handle special tokens - if index == self.bos_id: - return str(self.bos_token) - elif index == self.eos_id: - return str(self.eos_token) - elif index == self.pad_id: - return str(self.pad_token) - elif index == self.boe_id: - return str(self.boe_token) - elif index >= self.offsetting_special_char and index < self.vocab_size: - # Convert back to byte value - byte_val = index - self.offsetting_special_char - return str(byte_val) - else: - # Check added tokens - for token, token_id in self.added_tokens_encoder.items(): - if token_id == index: - return token - return str(self.unk_token) + # Check added tokens first (they might override special token IDs) + for token, token_id in self.added_tokens_encoder.items(): + if token_id == index: + return token + + return self.decoder.get(index, str(self.unk_token)) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens to a single string.""" byte_values = [] - + for token in tokens: - # Skip special tokens - if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]: + # Skip special tokens by checking if they're in encoder but not byte tokens + if token in self.encoder and token in { + str(self.bos_token), + str(self.eos_token), + str(self.pad_token), + str(self.boe_token), + }: continue - + try: - # Convert token back to byte value byte_val = int(token) if 0 <= byte_val <= 255: byte_values.append(byte_val) except ValueError: continue - - # Convert byte values back to string - try: - return bytes(byte_values).decode("utf-8", errors="ignore") - except (UnicodeDecodeError, ValueError): - return "" - - def encode(self, text: str, add_special_tokens: bool = True, **kwargs): - add_bos = kwargs.get('add_bos', self.add_bos_token if add_special_tokens else False) - add_eos = kwargs.get('add_eos', self.add_eos_token if add_special_tokens else False) - - # Since bpe_delim=False, we use the simple byte encoding - tokens = bytes(text, encoding="utf-8", errors="ignore") - - # Offsetting - tokens = [int(unit) + self.offsetting_special_char for unit in tokens] - - if add_bos: - tokens.insert(0, self.bos_id) - if add_eos: - tokens.append(self.eos_id) - - return tokens - - def decode(self, tokens, cut_at_eos: bool = False): - if cut_at_eos: - for k, t in enumerate(tokens): - if t == self.eos_id: - tokens = tokens[: k + 1] - break - return bytes( - [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] - ).decode("utf-8", errors="ignore") - - def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + + return bytes(byte_values).decode("utf-8", errors="ignore") + + def _tokenize(self, text: str, **kwargs) -> List[str]: + """Converts a string to a list of tokens. For BLT, we work directly with byte values.""" + return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] + + # def decode(self, token_ids, skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, **kwargs): + # """Converts a sequence of ids in a string, using the tokenizer and vocabulary.""" + # return super().decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: """ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A BLT sequence has the following format: @@ -277,8 +207,43 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: return bos + token_ids_0 + eos return bos + token_ids_0 + eos + token_ids_1 + eos + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + ([0] * len(token_ids_1)) + eos_token_id + def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" - return self.vocab_size_unit_1 + self.offsetting_special_char + return self.vocab_size + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + # BLT doesn't require external vocabulary files since it uses byte-level tokenization + return () + -__all__ = ["BLTTokenizer"] \ No newline at end of file +__all__ = ["BLTTokenizer"] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index f6e240f8d79d..5910d46eabbe 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -34,14 +34,10 @@ if is_torch_available(): import torch - from transformers import ( - BLTConfig, - BLTForCausalLM, - BLTModel, - BLTTokenizer - ) + from transformers import BLTConfig, BLTForCausalLM, BLTModel, BLTTokenizer from transformers.models.blt.modeling_blt import BLTRotaryEmbedding + class BLTModelTester(CausalLMModelTester): if is_torch_available(): config_class = BLTConfig @@ -99,7 +95,7 @@ def test_blt(self): prompt = "my name is" model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", + "itazap/blt-1b", device_map="auto", ) @@ -113,31 +109,85 @@ def test_blt(self): output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) - @slow @require_read_token def test_model_logits(self): + EXPECTED_OUTPUT = torch.tensor( + [ + [ + -10.4948, + -10.7065, + -6.1813, + -10.5545, + -10.3428, + -9.1493, + -8.4937, + -8.6382, + -9.2159, + -9.5907, + -9.3679, + -8.4184, + -9.0655, + -3.4436, + 2.9616, + -10.3157, + -6.3723, + -6.0133, + -9.7100, + -9.2128, + -8.8064, + -9.8179, + -9.7516, + -9.4681, + -9.7715, + -9.4897, + -9.0491, + -9.8098, + -9.4648, + -9.3294, + ], + [ + -13.3010, + -13.1910, + -5.7230, + -13.2895, + -13.4864, + -8.7140, + -7.0275, + -7.0182, + -10.1362, + -10.3762, + -9.9086, + -7.8049, + -8.8660, + -5.2711, + -3.5778, + -12.5346, + -9.1609, + -6.7925, + -10.3717, + -9.2650, + -10.6393, + -11.4807, + -11.2128, + -10.9615, + -10.5806, + -10.8873, + -11.0651, + -11.3471, + -10.5437, + -9.9688, + ], + ] + ).to(torch_device) - EXPECTED_OUTPUT = torch.tensor([[-10.4948, -10.7065, -6.1813, -10.5545, -10.3428, -9.1493, -8.4937, - -8.6382, -9.2159, -9.5907, -9.3679, -8.4184, -9.0655, -3.4436, - 2.9616, -10.3157, -6.3723, -6.0133, -9.7100, -9.2128, -8.8064, - -9.8179, -9.7516, -9.4681, -9.7715, -9.4897, -9.0491, -9.8098, - -9.4648, -9.3294], - [-13.3010, -13.1910, -5.7230, -13.2895, -13.4864, -8.7140, -7.0275, - -7.0182, -10.1362, -10.3762, -9.9086, -7.8049, -8.8660, -5.2711, - -3.5778, -12.5346, -9.1609, -6.7925, -10.3717, -9.2650, -10.6393, - -11.4807, -11.2128, -10.9615, -10.5806, -10.8873, -11.0651, -11.3471, - -10.5437, -9.9688]]).to(torch_device) - input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", device_map="auto" - ) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto") with torch.no_grad(): - output = model(torch.tensor([input_ids]).to(torch_device))[0] - + output = model(torch.tensor([input_ids]).to(torch_device))[0] + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4) @slow @@ -147,14 +197,10 @@ def test_model_bf16(self): """Test BLT model with bfloat16 precision.""" NUM_TOKENS_TO_GENERATE = 200 EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" - + prompt = "my name is" - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", - device_map="auto", - torch_dtype=torch.bfloat16 - ) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") @@ -170,24 +216,79 @@ def test_model_bf16(self): @require_torch_bf16 def test_model_logits_bf16(self): """Test BLT model logits with bfloat16 precision.""" - EXPECTED_OUTPUT = torch.tensor([[-10.5000, -10.7500, -6.2188, -10.5625, -10.3750, -9.1875, -8.5000, - -8.6250, -9.1875, -9.6250, -9.3750, -8.5000, -9.0625, -3.4219, - 2.9688, -10.3125, -6.4062, -6.0000, -9.6875, -9.2500, -8.8125, - -9.8750, -9.7500, -9.5000, -9.8125, -9.5000, -9.0625, -9.8750, - -9.5000, -9.3750], - [-13.3750, -13.2500, -5.5938, -13.3750, -13.5000, -8.7500, -7.0312, - -7.0000, -10.1875, -10.3750, -9.8750, -7.8125, -8.8750, -5.3125, - -3.5469, -12.5625, -9.1875, -6.7812, -10.3750, -9.2500, -10.6250, - -11.5000, -11.2500, -11.0000, -10.6250, -10.9375, -11.1250, -11.3750, - -10.5625, -10.0000]], dtype=torch.bfloat16).to(torch_device) - + EXPECTED_OUTPUT = torch.tensor( + [ + [ + -10.5000, + -10.7500, + -6.2188, + -10.5625, + -10.3750, + -9.1875, + -8.5000, + -8.6250, + -9.1875, + -9.6250, + -9.3750, + -8.5000, + -9.0625, + -3.4219, + 2.9688, + -10.3125, + -6.4062, + -6.0000, + -9.6875, + -9.2500, + -8.8125, + -9.8750, + -9.7500, + -9.5000, + -9.8125, + -9.5000, + -9.0625, + -9.8750, + -9.5000, + -9.3750, + ], + [ + -13.3750, + -13.2500, + -5.5938, + -13.3750, + -13.5000, + -8.7500, + -7.0312, + -7.0000, + -10.1875, + -10.3750, + -9.8750, + -7.8125, + -8.8750, + -5.3125, + -3.5469, + -12.5625, + -9.1875, + -6.7812, + -10.3750, + -9.2500, + -10.6250, + -11.5000, + -11.2500, + -11.0000, + -10.6250, + -10.9375, + -11.1250, + -11.3750, + -10.5625, + -10.0000, + ], + ], + dtype=torch.bfloat16, + ).to(torch_device) + input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", - device_map="auto", - torch_dtype=torch.bfloat16 - ) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -203,11 +304,7 @@ def test_model_eager(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", - device_map="auto", - attn_implementation="eager" - ) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", attn_implementation="eager") tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") @@ -228,11 +325,7 @@ def test_model_bf16_static_cache(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", - device_map="auto", - torch_dtype=torch.bfloat16 - ) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) model.generation_config.cache_implementation = "static" diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py index 62af101b1d83..503bf504c3ad 100644 --- a/tests/models/blt/test_tokenization_blt.py +++ b/tests/models/blt/test_tokenization_blt.py @@ -12,903 +12,241 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import pickle -import shutil -import tempfile import unittest -from datasets import load_dataset -from huggingface_hub import hf_hub_download - -from transformers import ( - SPIECE_UNDERLINE, - AddedToken, - AutoTokenizer, - BLTTokenizer, - BLTTokenizerFast, - PreTrainedTokenizerFast, -) -from transformers.convert_slow_tokenizer import convert_slow_tokenizer -from transformers.testing_utils import ( - get_tests_dir, - nested_simplify, - require_jinja, - require_read_token, - require_sentencepiece, - require_tiktoken, - require_tokenizers, - require_torch, - slow, -) +from transformers import BLTTokenizer +from transformers.testing_utils import require_tokenizers +from transformers.tokenization_utils import AddedToken from ...test_tokenization_common import TokenizerTesterMixin -SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") - - -@require_sentencepiece @require_tokenizers class BLTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): - from_pretrained_id = ["hf-internal-testing/blt-tokenizer", "meta-blt/BLT-2-7b-hf"] + from_pretrained_id = [] tokenizer_class = BLTTokenizer - rust_tokenizer_class = BLTTokenizerFast + rust_tokenizer_class = None test_rust_tokenizer = False - test_sentencepiece = True + test_sentencepiece = False + test_slow_tokenizer = True from_pretrained_kwargs = {} @classmethod def setUpClass(cls): super().setUpClass() - - # We have a SentencePiece fixture for testing - tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True) - tokenizer.pad_token = tokenizer.eos_token + # Create a simple BLT tokenizer for testing + tokenizer = BLTTokenizer() tokenizer.save_pretrained(cls.tmpdirname) def get_tokenizers(self, **kwargs): - kwargs.update({"pad_token": ""}) + kwargs.update({"add_bos_token": True, "add_eos_token": False}) return super().get_tokenizers(**kwargs) - def test_full_tokenizer(self): - tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True) - - tokens = tokenizer.tokenize("This is a test") - self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) - - self.assertListEqual( - tokenizer.convert_tokens_to_ids(tokens), - [285, 46, 10, 170, 382], - ) - - tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") - self.assertListEqual( - tokens, - [ - SPIECE_UNDERLINE + "I", - SPIECE_UNDERLINE + "was", - SPIECE_UNDERLINE + "b", - "or", - "n", - SPIECE_UNDERLINE + "in", - SPIECE_UNDERLINE + "", - "9", - "2", - "0", - "0", - "0", - ",", - SPIECE_UNDERLINE + "and", - SPIECE_UNDERLINE + "this", - SPIECE_UNDERLINE + "is", - SPIECE_UNDERLINE + "f", - "al", - "s", - "é", - ".", - ], - ) - ids = tokenizer.convert_tokens_to_ids(tokens) - self.assertListEqual( - ids, - [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4], - ) - - back_tokens = tokenizer.convert_ids_to_tokens(ids) - self.assertListEqual( - back_tokens, - [ - SPIECE_UNDERLINE + "I", - SPIECE_UNDERLINE + "was", - SPIECE_UNDERLINE + "b", - "or", - "n", - SPIECE_UNDERLINE + "in", - SPIECE_UNDERLINE + "", - "", - "2", - "0", - "0", - "0", - ",", - SPIECE_UNDERLINE + "and", - SPIECE_UNDERLINE + "this", - SPIECE_UNDERLINE + "is", - SPIECE_UNDERLINE + "f", - "al", - "s", - "", - ".", - ], - ) - - @unittest.skip(reason="Let's wait for the fast tokenizer!") - def test_save_pretrained(self): - self.tokenizers_list += (self.rust_tokenizer_class, "hf-internal-testing/blt-tokenizer", {}) - for tokenizer, pretrained_name, kwargs in self.tokenizers_list: - with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): - tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs) - tokenizer_p = self.get_tokenizer(pretrained_name, **kwargs) - - tmpdirname2 = tempfile.mkdtemp() - - tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2) - tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) - - # Checks it save with the same files + the tokenizer.json file for the fast one - self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files)) - tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f) - self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files) - - # Checks everything loads correctly in the same way - tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) - tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) - - # Check special tokens are set accordingly on Rust and Python - for key in tokenizer_pp.special_tokens_map: - self.assertTrue(hasattr(tokenizer_rp, key)) - - shutil.rmtree(tmpdirname2) - - # Save tokenizer rust, legacy_format=True - tmpdirname2 = tempfile.mkdtemp() - - tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True) - tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) - - # Checks it save with the same files - self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files) - - # Checks everything loads correctly in the same way - tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) - tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) - - # Check special tokens are set accordingly on Rust and Python - for key in tokenizer_pp.special_tokens_map: - self.assertTrue(hasattr(tokenizer_rp, key)) - - shutil.rmtree(tmpdirname2) - - # Save tokenizer rust, legacy_format=False - tmpdirname2 = tempfile.mkdtemp() - - tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False) - tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2) - - # Checks it saved the tokenizer.json file - self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files)) - - # Checks everything loads correctly in the same way - tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2) - tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2) - - # Check special tokens are set accordingly on Rust and Python - for key in tokenizer_pp.special_tokens_map: - self.assertTrue(hasattr(tokenizer_rp, key)) - - shutil.rmtree(tmpdirname2) - - @require_torch - def test_batch_tokenization(self): - if not self.test_seq2seq: - self.skipTest(reason="test_seq2seq is set to False") - - tokenizers = self.get_tokenizers() - for tokenizer in tokenizers: - with self.subTest(f"{tokenizer.__class__.__name__}"): - # Longer text that will definitely require truncation. - text = [ - " UN Chief Says There Is No Military Solution in Syria", - " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for" - " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons" - " will only worsen the violence and misery for millions of people.", - ] - try: - batch = tokenizer( - text=text, - max_length=3, - max_target_length=10, - return_tensors="pt", - ) - except NotImplementedError: - self.skipTest(reason="Encountered NotImplementedError when calling tokenizer") - self.assertEqual(batch.input_ids.shape[1], 3) - # max_target_length will default to max_length if not specified - batch = tokenizer(text, max_length=3, return_tensors="pt") - self.assertEqual(batch.input_ids.shape[1], 3) - - batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt") - self.assertEqual(batch_encoder_only.input_ids.shape[1], 3) - self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3) - self.assertNotIn("decoder_input_ids", batch_encoder_only) - - @unittest.skip(reason="Unfortunately way too slow to build a BPE with SentencePiece.") - def test_save_slow_from_fast_and_reload_fast(self): - pass + def test_blt_tokenizer_basics(self): + """Test basic BLT tokenizer functionality""" + tokenizer = BLTTokenizer() + + # Test vocab size (256 bytes + 4 offset + special tokens) + self.assertEqual(tokenizer.vocab_size, 261) + + # Test special token IDs + self.assertEqual(tokenizer.bos_id, 1) + self.assertEqual(tokenizer.eos_id, 2) + self.assertEqual(tokenizer.boe_id, 0) + self.assertEqual(tokenizer.pad_id, 260) + + # Test special tokens + self.assertEqual(str(tokenizer.bos_token), "") + self.assertEqual(str(tokenizer.eos_token), "") + self.assertEqual(str(tokenizer.boe_token), "") + self.assertEqual(str(tokenizer.pad_token), "") + + def test_simple_encode_decode(self): + tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + + text = "Hello" + encoded = tokenizer.encode(text, add_special_tokens=False) + + # "Hello" in UTF-8 bytes: [72, 101, 108, 108, 111] + # With offset +4: [76, 105, 112, 112, 115] + expected = [76, 105, 112, 112, 115] + self.assertEqual(encoded, expected) + + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, text) - def test_special_tokens_initialization(self): - for tokenizer, pretrained_name, kwargs in self.tokenizers_list: - with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): - added_tokens = [AddedToken("", lstrip=True)] - - tokenizer_r = self.get_rust_tokenizer( - pretrained_name, additional_special_tokens=added_tokens, **kwargs - ) - r_output = tokenizer_r.encode("Hey this is a token") - - special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] - - self.assertTrue(special_token_id in r_output) - - if self.test_slow_tokenizer: - tokenizer_cr = self.get_rust_tokenizer( - pretrained_name, - additional_special_tokens=added_tokens, - **kwargs, # , from_slow=True <- unfortunately too slow to convert - ) - tokenizer_p = self.tokenizer_class.from_pretrained( - pretrained_name, additional_special_tokens=added_tokens, **kwargs - ) - - p_output = tokenizer_p.encode("Hey this is a token") - - cr_output = tokenizer_cr.encode("Hey this is a token") - - self.assertEqual(p_output, r_output) - self.assertEqual(cr_output, r_output) - self.assertTrue(special_token_id in p_output) - self.assertTrue(special_token_id in cr_output) - - @slow - def test_tokenizer_integration(self): - expected_encoding = {'input_ids': [[1, 4103, 689, 414, 313, 24784, 368, 2998, 408, 282, 3637, 25350, 29899, 9067, 414, 322, 282, 3637, 25350, 29899, 1457, 3018, 1312, 29899, 2151, 29897, 8128, 2498, 29899, 15503, 4220, 6956, 1973, 313, 13635, 29911, 29892, 402, 7982, 29899, 29906, 29892, 1528, 13635, 29911, 29874, 29892, 1060, 26369, 29892, 6652, 309, 29933, 814, 29892, 1060, 29931, 6779, 11410, 363, 18385, 17088, 7634, 11235, 313, 25103, 29965, 29897, 322, 18385, 17088, 28203, 313, 25103, 29954, 29897, 411, 975, 29871, 29941, 29906, 29974, 758, 3018, 1312, 4733, 297, 29871, 29896, 29900, 29900, 29974, 10276, 322, 6483, 1006, 3372, 3097, 1546, 435, 1165, 29892, 10772, 29911, 25350, 322, 323, 6073, 17907, 29889], [1, 350, 20161, 338, 8688, 304, 758, 29899, 14968, 6483, 21000, 8684, 284, 22540, 515, 443, 29880, 24025, 1426, 491, 14002, 368, 4195, 292, 373, 1716, 2175, 322, 1492, 3030, 297, 599, 15359, 29889], [1, 450, 4996, 17354, 1701, 29916, 432, 17204, 975, 278, 17366, 11203, 29889]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip - - self.tokenizer_integration_test_util( - expected_encoding=expected_encoding, - model_name="hf-internal-testing/blt-tokenizer", - revision="0984d03108b1a041ed679bd253b6519b7e1a4778", - padding=False, - ) - - def test_picklable(self): - with tempfile.NamedTemporaryFile() as f: - shutil.copyfile(SAMPLE_VOCAB, f.name) - tokenizer = BLTTokenizer(f.name, keep_accents=True) - pickled_tokenizer = pickle.dumps(tokenizer) - pickle.loads(pickled_tokenizer) - - @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.") - def test_pickle_subword_regularization_tokenizer(self): + def test_special_tokens_encoding(self): + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + + text = "Hi" + encoded = tokenizer.encode(text, add_special_tokens=True) + + # "Hi" in UTF-8 bytes: [72, 105] -> with offset: [76, 109] + # With BOS (1) and EOS (2): [1, 76, 109, 2] + expected = [1, 76, 109, 2] + self.assertEqual(encoded, expected) + + def test_tokenize_method(self): + tokenizer = BLTTokenizer() + + text = "ABC" + tokens = tokenizer._tokenize(text) + + # "ABC" in UTF-8 bytes: [65, 66, 67] + expected = ["65", "66", "67"] + self.assertEqual(tokens, expected) + + def test_token_conversion(self): + """Test token to ID and ID to token conversion""" + tokenizer = BLTTokenizer() + + # Test byte token conversion + token = "65" # Byte value for 'A' + token_id = tokenizer._convert_token_to_id(token) + self.assertEqual(token_id, 69) # 65 + 4 offset + + converted_token = tokenizer._convert_id_to_token(token_id) + self.assertEqual(converted_token, token) + + bos_id = tokenizer._convert_token_to_id(str(tokenizer.bos_token)) + self.assertEqual(bos_id, 1) + + bos_token = tokenizer._convert_id_to_token(1) + self.assertEqual(bos_token, str(tokenizer.bos_token)) + + def test_convert_tokens_to_string(self): + tokenizer = BLTTokenizer() + + tokens = ["72", "101", "108", "108", "111"] # "Hello" in bytes + result = tokenizer.convert_tokens_to_string(tokens) + self.assertEqual(result, "Hello") + + # Test with special tokens mixed in (should be ignored) + tokens_with_special = [str(tokenizer.bos_token), "72", "105", str(tokenizer.eos_token)] + result = tokenizer.convert_tokens_to_string(tokens_with_special) + self.assertEqual(result, "Hi") + + def test_unicode_handling(self): + tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + + # Test Unicode character (é) + text = "café" + encoded = tokenizer.encode(text, add_special_tokens=False) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, text) + + # Test emoji + text = "Hello 👋" + encoded = tokenizer.encode(text, add_special_tokens=False) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, text) + + def test_empty_and_whitespace(self): + tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + + # Test empty string + encoded = tokenizer.encode("", add_special_tokens=False) + self.assertEqual(encoded, []) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, "") + + # Test single space + encoded = tokenizer.encode(" ", add_special_tokens=False) + self.assertEqual(encoded, [36]) # 32 (space) + 4 offset + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, " ") + + def test_get_vocab(self): + tokenizer = BLTTokenizer() + vocab = tokenizer.get_vocab() + + # Should contain special tokens + self.assertIn(str(tokenizer.bos_token), vocab) + self.assertIn(str(tokenizer.eos_token), vocab) + self.assertIn(str(tokenizer.boe_token), vocab) + self.assertIn(str(tokenizer.pad_token), vocab) + + # Should contain byte representations + self.assertIn("0", vocab) # First byte + self.assertIn("255", vocab) # Last byte + + self.assertEqual(vocab[str(tokenizer.bos_token)], 1) + self.assertEqual(vocab[str(tokenizer.eos_token)], 2) + self.assertEqual(vocab["0"], 4) # 0 + 4 offset + self.assertEqual(vocab["255"], 259) # 255 + 4 offset + + def test_build_inputs_with_special_tokens(self): + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + + # Single sequence + token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) + result = tokenizer.build_inputs_with_special_tokens(token_ids) + expected = [1, 76, 109, 2] # BOS + tokens + EOS + self.assertEqual(result, expected) + + # Pair of sequences + token_ids_1 = [76, 109] # "Hi" + token_ids_2 = [66, 121, 101] # "Bye" + result = tokenizer.build_inputs_with_special_tokens(token_ids_1, token_ids_2) + expected = [1, 76, 109, 2, 66, 121, 101, 2] # BOS + seq1 + EOS + seq2 + EOS + self.assertEqual(result, expected) + + def test_special_tokens_mask(self): + tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + + token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) + mask = tokenizer.get_special_tokens_mask(token_ids) + expected = [1, 0, 0, 1] # BOS=1, content=0, content=0, EOS=1 + self.assertEqual(mask, expected) + + def test_add_special_tokens_flags(self): + tokenizer1 = BLTTokenizer(add_bos_token=True, add_eos_token=True) + encoded1 = tokenizer1.encode("Hi", add_special_tokens=True) + self.assertEqual(encoded1[0], 1) # BOS + self.assertEqual(encoded1[-1], 2) # EOS + + tokenizer2 = BLTTokenizer(add_bos_token=False, add_eos_token=False) + encoded2 = tokenizer2.encode("Hi", add_special_tokens=True) + self.assertNotEqual(encoded2[0], 1) # No BOS + self.assertNotEqual(encoded2[-1], 2) # No EOS + + # Test with only BOS + tokenizer3 = BLTTokenizer(add_bos_token=True, add_eos_token=False) + encoded3 = tokenizer3.encode("Hi", add_special_tokens=True) + self.assertEqual(encoded3[0], 1) # BOS + self.assertNotEqual(encoded3[-1], 2) # No EOS + + def test_added_tokens(self): + tokenizer = BLTTokenizer() + + custom_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([custom_token]) + + self.assertIn("", tokenizer.get_vocab()) + + token_id = tokenizer._convert_token_to_id("") + self.assertIsInstance(token_id, int) + + back_token = tokenizer._convert_id_to_token(token_id) + self.assertEqual(back_token, "") + + @unittest.skip("BLT is byte-level, special tokens are encoded as bytes") + def test_add_special_tokens(self): pass - @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.") - def test_subword_regularization_tokenizer(self): + @unittest.skip("BLT byte-level tokenization doesn't handle pretokenized inputs the same way") + def test_pretokenized_inputs(self): pass - def test_add_prefix_space(self): - pretrained_name = "hf-internal-testing/blt-tokenizer-non-normalized" - inputs = "Hey how are you doing" - EXPECTED_WITH_SPACE = [1, 18637, 920, 526, 366, 2599] - EXPECTED_WO_SPACE = [1, 29950, 1032, 920, 526, 366, 2599] - - slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=False, legacy=False) - fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=False, legacy=False) - self.assertEqual(slow_.encode(inputs), EXPECTED_WO_SPACE) - self.assertEqual(slow_.encode(inputs), fast_.encode(inputs)) - self.assertEqual(slow_.tokenize(inputs), ["H", "ey", "▁how", "▁are", "▁you", "▁doing"]) - self.assertEqual(slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), inputs) - self.assertEqual( - slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), - fast_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), - ) - - slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=True, legacy=False) - fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=True, legacy=False) - self.assertEqual(slow_.encode(inputs), EXPECTED_WITH_SPACE) - self.assertEqual(slow_.encode(inputs), fast_.encode(inputs)) - self.assertEqual(slow_.tokenize(inputs), ["▁Hey", "▁how", "▁are", "▁you", "▁doing"]) - self.assertEqual(slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), inputs) - self.assertEqual( - slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), - fast_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), - ) - - def test_load_tokenizer_with_model_file_only(self): - with tempfile.TemporaryDirectory() as tmp_dir: - hf_hub_download(repo_id="huggyblt/blt-7b", filename="tokenizer.model", local_dir=tmp_dir) - tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir) - self.assertEqual(tokenizer_fast.encode("This is a test"), [1, 910, 338, 263, 1243]) - - tokenizer_slow = self.tokenizer_class.from_pretrained(tmp_dir) - self.assertEqual(tokenizer_slow.encode("This is a test"), [1, 910, 338, 263, 1243]) - - -@require_torch -@require_sentencepiece -@require_tokenizers -class BLTIntegrationTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - checkpoint_name = "hf-internal-testing/blt-tokenizer-non-normalized" - cls.tokenizer: BLTTokenizer = BLTTokenizer.from_pretrained(checkpoint_name) - cls.rust_tokenizer = BLTTokenizerFast.from_pretrained(checkpoint_name) - return cls - - @require_torch - def integration_tests(self): - inputs = self.tokenizer( - ["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"], - return_tensors="pt", - ) - - self.assertEqual( - nested_simplify(inputs), - { - "input_ids": [ - [1, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889], - [1, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718], - ], - "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], - }, - ) - - def test_fast_special_tokens(self): - slow_tokenizer = self.tokenizer - fast_tokenizer = self.rust_tokenizer - slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) - assert slow == [1, 319, 4559, 1243] - - fast_tokenizer.add_eos_token = False - fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) - assert fast == [1, 319, 4559, 1243] - - fast_tokenizer.add_eos_token = True - print(fast_tokenizer.add_eos_token) - fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) - assert fast == [1, 319, 4559, 1243, 2] - - slow_tokenizer.add_eos_token = True - slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) - assert slow == [1, 319, 4559, 1243, 2] - - fast_tokenizer = BLTTokenizerFast.from_pretrained( - "hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False - ) - fast = fast_tokenizer.encode("A sample test", add_special_tokens=True) - assert fast == [319, 4559, 1243, 2] - - slow_tokenizer = BLTTokenizer.from_pretrained( - "hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False - ) - slow = slow_tokenizer.encode("A sample test", add_special_tokens=True) - assert slow == [319, 4559, 1243, 2] - - self.tokenizer.add_eos_token = False - self.rust_tokenizer.add_eos_token = False - - @slow - def test_conversion(self): - # This is excruciatingly slow since it has to recreate the entire merge - # list from the original vocabulary in spm - self.rust_tokenizer.save_pretrained("./out") - with tempfile.TemporaryDirectory() as dirname: - self.rust_tokenizer.save_pretrained(dirname) - - with open(os.path.join(dirname, "tokenizer.json")) as f: - old_serialized = f.read() - - new_tokenizer = convert_slow_tokenizer(self.tokenizer) - with tempfile.NamedTemporaryFile() as f: - new_tokenizer.save(f.name) - # Re-opening since `f` is in bytes. - new_serialized = open(f.name).read() - with open("out_tokenizer.json", "w") as g: - g.write(new_serialized) - - self.assertEqual(old_serialized, new_serialized) + @unittest.skip("BLT encodes added tokens as bytes, not single tokens") + def test_add_tokens_tokenizer(self): + pass - def test_simple_encode_decode(self): - pyth_tokenizer = self.tokenizer - rust_tokenizer = self.rust_tokenizer - - self.assertEqual(pyth_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243]) - self.assertEqual(rust_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243]) - self.assertEqual(pyth_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test") - self.assertEqual(rust_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test") - - # bytefallback showcase - self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip - self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip - self.assertEqual( - pyth_tokenizer.decode( - [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True - ), - "生活的真谛是", - ) - self.assertEqual( - rust_tokenizer.decode( - [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True - ), - "生活的真谛是", - ) - - # Inner spaces showcase - self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043]) - self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043]) - self.assertEqual(pyth_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello") - self.assertEqual(rust_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello") + @unittest.skip("BLT tokenizer serialization needs additional work for added tokens") + def test_save_and_load_tokenizer(self): + pass - self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043]) - self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043]) - self.assertEqual(pyth_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello") - self.assertEqual(rust_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello") - - self.assertEqual(pyth_tokenizer.encode(""), [1]) - self.assertEqual(rust_tokenizer.encode(""), [1]) - - self.assertEqual(pyth_tokenizer.encode(" "), [1, 259]) - self.assertEqual(rust_tokenizer.encode(" "), [1, 259]) - - self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678]) - self.assertEqual(rust_tokenizer.encode(" "), [1, 1678]) - - self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043]) - self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043]) - - def test_no_differences_showcase(self): - pyth_tokenizer = self.tokenizer - rust_tokenizer = self.rust_tokenizer - self.assertEqual(pyth_tokenizer.encode(""), [1]) - self.assertEqual(rust_tokenizer.encode(""), [1]) - - self.assertEqual(pyth_tokenizer.encode(" "), [1, 259]) - self.assertEqual(rust_tokenizer.encode(" "), [1, 259]) - - self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678]) - self.assertEqual(rust_tokenizer.encode(" "), [1, 1678]) - - self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043]) - self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043]) - - self.assertEqual(pyth_tokenizer.encode(""), [1, 1]) - self.assertEqual(rust_tokenizer.encode(""), [1, 1]) - - def test_no_differences_decode(self): - pyth_tokenizer = self.tokenizer - rust_tokenizer = self.rust_tokenizer - - self.assertEqual(pyth_tokenizer.decode([869]), ".") - self.assertEqual(rust_tokenizer.decode([869]), ".") - - self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .") - self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .") - - def test_no_differences_special_tokens(self): - pyth_tokenizer = self.tokenizer - rust_tokenizer = self.rust_tokenizer - self.assertEqual(pyth_tokenizer.encode(""), [1]) - self.assertEqual(rust_tokenizer.encode(""), [1]) - - self.assertEqual(pyth_tokenizer.encode(""), [1, 1]) - self.assertEqual(rust_tokenizer.encode(""), [1, 1]) - - @unittest.skipIf( - os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0", - "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests", - ) - def test_integration_test_xnli(self): - import tqdm - - pyth_tokenizer = self.tokenizer - rust_tokenizer = self.rust_tokenizer - - dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go") - for item in tqdm.tqdm(dataset["validation"]): - string = item["code"] - encoded1 = pyth_tokenizer.encode(string) - encoded2 = rust_tokenizer.encode(string) - - self.assertEqual(encoded1, encoded2) - - decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True) - decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True) - - self.assertEqual(decoded1, decoded2) - - dataset = load_dataset("facebook/xnli", "all_languages") - - for item in tqdm.tqdm(dataset["train"]): - for string in item["premise"].values(): - encoded1 = pyth_tokenizer.encode(string) - encoded2 = rust_tokenizer.encode(string) - - self.assertEqual(encoded1, encoded2) - - decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True) - decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True) - - self.assertEqual(decoded1, decoded2) - - def test_special_token_special_word(self): - # the word inform should be split as ['in', 'form'] - tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True) - tokenizer.add_tokens([AddedToken("", rstrip=True, lstrip=True)], special_tokens=False) - - example_inputs = tokenizer.tokenize("inform. Hey. .") - self.assertEqual(example_inputs, ["", "in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) - - # Make sure dummy space is added if it is indeed the first word - example_inputs = tokenizer.tokenize("inform. Hey. .") - self.assertEqual(example_inputs, ["▁inform", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) - out1 = tokenizer.decode( - tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=False - ) - self.assertEqual(out1, "inform") - out2 = tokenizer.decode( - tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=True - ) - # decoding strips the added prefix space. - self.assertEqual(out2, "inform") - input_ids = tokenizer.encode("inform", add_special_tokens=False) - self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should - - out2 = tokenizer.decode( - tokenizer.encode(" inform", add_special_tokens=False), spaces_between_special_tokens=False - ) - # TODO @ArthurZ currently we strip left and right, so this will not keep the spaces - self.assertEqual(out2, "inform") - - ### Let's make sure decoding does not add extra spaces here and there - # TODO @ArthurZ this should be affected by the lstrip/rstrip/single word /normalize refactoring - # Since currently we always strip left and right of the token, results are as such - input_ids = tokenizer.encode(" Hellohow", add_special_tokens=False) - self.assertEqual(input_ids, [1, 15043, 1, 3525]) - tokens = tokenizer.tokenize(" Hellohow", add_special_tokens=False) - self.assertEqual(tokens, ["", "▁Hello", "", "how"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, " Hellohow") - - # Let's make sure that if there are any spaces, we don't remove them! - input_ids = tokenizer.encode(" Hello how", add_special_tokens=False) - self.assertEqual(input_ids, [29871, 1, 15043, 1, 920]) - tokens = tokenizer.tokenize(" Hello how", add_special_tokens=False) - self.assertEqual(tokens, ["▁", "", "▁Hello", "", "▁how"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, " Hello how") - - # Let's make sure the space is preserved - input_ids = tokenizer.encode("hello", add_special_tokens=True) - self.assertEqual(input_ids, [1, 22172]) - tokens = tokenizer.tokenize("hello") - self.assertEqual(tokens, ["▁hello"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, " hello") - - input_ids = tokenizer.encode("hello", add_special_tokens=False) - self.assertEqual(input_ids, [22172]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, "hello") - - def test_no_prefix_space(self): - tokenizer_no_prefix_space = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", add_prefix_space=False) - no_prefix_space_tokens = tokenizer_no_prefix_space.tokenize("Hey") - self.assertEqual(no_prefix_space_tokens, ["H", "ey"]) - - tokenizer = BLTTokenizerFast.from_pretrained( - "huggyblt/blt-7b", legacy=False, from_slow=True, add_prefix_space=False - ) - tokenizer.add_tokens([AddedToken("", rstrip=True, lstrip=True)], special_tokens=False) - - example_inputs = tokenizer.tokenize("inform. Hey. .") - self.assertEqual(example_inputs, ["", "in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) - - # Make sure dummy space is added if it is indeed the first word - example_inputs = tokenizer.tokenize("inform. Hey. .") - self.assertEqual(example_inputs, ["in", "form", "", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."]) - out1 = tokenizer.decode( - tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=False - ) - self.assertEqual(out1, "inform") - out2 = tokenizer.decode( - tokenizer.encode("inform", add_special_tokens=False), spaces_between_special_tokens=True - ) - # decoding strips the added prefix space. - self.assertEqual(out2, "inform") - input_ids = tokenizer.encode("inform", add_special_tokens=False) - self.assertEqual(input_ids, [32000, 262, 689]) # 29871 is the spiece underline, '▁' added as it should - - out2 = tokenizer.decode( - tokenizer.encode(" inform", add_special_tokens=False), spaces_between_special_tokens=False - ) - self.assertEqual(out2, "inform") - - input_ids = tokenizer.encode(" Hellohow", add_special_tokens=False) - self.assertEqual(input_ids, [1, 15043, 1, 3525]) - tokens = tokenizer.tokenize(" Hellohow", add_special_tokens=False) - self.assertEqual(tokens, ["", "▁Hello", "", "how"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, " Hellohow") - - # Let's make sure that if there are any spaces, we don't remove them! - input_ids = tokenizer.encode(" Hello how", add_special_tokens=False) - self.assertEqual(input_ids, [29871, 1, 15043, 1, 920]) - tokens = tokenizer.tokenize(" Hello how", add_special_tokens=False) - self.assertEqual(tokens, ["▁", "", "▁Hello", "", "▁how"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, " Hello how") - - # Let's make sure the space is preserved - input_ids = tokenizer.encode("hello", add_special_tokens=True) - self.assertEqual(input_ids, [1, 12199]) - tokens = tokenizer.tokenize("hello") - self.assertEqual(tokens, ["hello"]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, "hello") - - input_ids = tokenizer.encode("hello", add_special_tokens=False) - self.assertEqual(input_ids, [12199]) - decoded_tokens = tokenizer.decode(input_ids) - self.assertEqual(decoded_tokens, "hello") - - def test_some_edge_cases(self): - tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False) - - sp_tokens = tokenizer.sp_model.encode(">", out_type=str) - self.assertEqual(sp_tokens, ["<", "s", ">>"]) - tokens = tokenizer.tokenize(">") - self.assertNotEqual(sp_tokens, tokens) - self.assertEqual(tokens, ["", ">"]) - - tokens = tokenizer.tokenize("") - self.assertEqual(tokens, []) - self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str)) - - tokens = tokenizer.tokenize(" ") - self.assertEqual(tokens, ["▁▁"]) - # a dummy prefix space is not added by the sp_model as it was de-activated - self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str)) - - tokens = tokenizer.tokenize("▁") - self.assertEqual(tokens, ["▁▁"]) - # a dummy prefix space is not added by the sp_model as it was de-activated - self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁", out_type=str)) - - tokens = tokenizer.tokenize(" ▁") - self.assertEqual(tokens, ["▁▁▁"]) - # a dummy prefix space is not added by the sp_model as it was de-activated - self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str)) - - def test_fast_post_processor(self): - tokenizer = BLTTokenizerFast( - SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False - ) - tokenizer.encode(" Hey ") - - with self.assertRaises(ValueError): - tokenizer = BLTTokenizerFast( - SAMPLE_VOCAB, bos_token=None, eos_token="", add_bos_token=True, add_eos_token=False - ) - with self.assertRaises(ValueError): - tokenizer = BLTTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True) - - @require_jinja - def test_tokenization_for_chat(self): - tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False) - - test_chats = [ - [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], - [ - {"role": "system", "content": "You are a helpful chatbot."}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Nice to meet you."}, - ], - [{"role": "user", "content": "Hello!"}], - ] - # Matt: The third test case tests the default system message, but if this is ever changed in the - # class/repo code then that test will fail, and the case will need to be updated. - tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] - # fmt: off - expected_tokens = [ - [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962], - [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2], - [1, 29961, 25580, 29962, 15043, 29991, 518, 29914, 25580, 29962] - ] - # fmt: on - for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): - self.assertListEqual(tokenized_chat, expected_tokens) - - -@require_sentencepiece -@require_tokenizers -class CommonSpmIntegrationTests(unittest.TestCase): - """ - A class that regroups important test to make sure that we properly handle the special tokens. - """ - @classmethod - def setUpClass(cls): - tokenizer = BLTTokenizer(SAMPLE_VOCAB, extra_ids=0, add_bos_token=False, legacy=False) - tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]}) - cls.tokenizer = tokenizer - return cls - - def test_add_dummy_prefix(self): - # make sure `'▁'` is prepended, and outputs match sp_model's - # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute - input_ids = self.tokenizer.encode(". Hello") - self.assertEqual(input_ids, [7, 4, 156, 86, 20]) - sp_encode = self.tokenizer.sp_model.encode(". Hello") - self.assertEqual(input_ids, [7] + sp_encode) - tokens = self.tokenizer.tokenize(". Hello") - self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) - - tokens = self.tokenizer.tokenize("") - self.assertEqual(tokens, []) - self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str)) - - tokens = self.tokenizer.tokenize(" ") - self.assertEqual(tokens, []) - self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str)) - - tokens = self.tokenizer.tokenize("▁") - self.assertEqual(tokens, []) - self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str)) - - def test_remove_extra_whitespaces(self): - # make sure the extra spaces are eaten. Since the sample vocab does not have - # `______`. sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute is set to False - - input_ids = self.tokenizer.encode(" . Hello") - self.assertEqual(input_ids, [7, 4, 156, 86, 20]) - sp_encode = self.tokenizer.sp_model.encode(" . Hello") - self.assertEqual(input_ids, [7] + sp_encode) - tokens = self.tokenizer.tokenize(" . Hello") - self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) - - # `'▁'` is also a whitespace - input_ids = self.tokenizer.encode("▁He is not") - self.assertEqual(input_ids, [156, 46, 44]) - tokens = self.tokenizer.tokenize("▁He is not") - sp_encode = [ - self.tokenizer.sp_model.piece_to_id("▁He"), - self.tokenizer.sp_model.piece_to_id("▁is"), - self.tokenizer.sp_model.piece_to_id("▁not"), - ] - self.assertEqual(input_ids, sp_encode) - self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added - - input_ids = self.tokenizer.encode("▁He is not ▁He") - self.assertEqual(input_ids, [156, 46, 44, 1, 156]) - tokens = self.tokenizer.tokenize("▁He is not ▁He") - self.assertEqual(tokens, ["▁He", "▁is", "▁not", "", "▁He"]) # spaces are eaten by spm + our strip - # make sure that the output after the extra id is the same as if - # extra_id was not there - input_ids = self.tokenizer.encode("▁He is not ▁He") - self.assertEqual(input_ids, [156, 46, 44, 156]) - tokens = self.tokenizer.tokenize("▁He is not ▁He") - self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start - - def test_character_after_special_token(self): - # Make sure that `tokenizer.tokenize` is similar to - # adding the equivalent special token to the vocab - input_ids = self.tokenizer.encode("Hey I") - self.assertEqual(input_ids, [156, 30, 1, 100]) - sp_encode = self.tokenizer.sp_model.encode("Hey .I") - # the last token should be 100 - self.assertEqual(input_ids[-1], sp_encode[-1]) - tokens = self.tokenizer.tokenize("I") - self.assertEqual(tokens, ["", "I"]) - - input_ids = self.tokenizer.encode("Hello, ,") - self.assertEqual(input_ids, [156, 86, 20, 3, 1, 3]) - tokens = self.tokenizer.tokenize("Hello, ,") - self.assertEqual(tokens, ["▁He", "ll", "o", ",", "", ","]) - - def test_special_tokens_strip(self): - input_ids = self.tokenizer.encode(" ,") - self.assertEqual(input_ids, [1, 7, 3]) - tokens = self.tokenizer.tokenize(" ,") - # spaces are eaten by rstrip / lstrip + spm sp_model.encode(" ") = [] - self.assertEqual(tokens, ["", "▁", ","]) - - input_ids = self.tokenizer.encode("No ▁He") - self.assertEqual(input_ids, [284, 1, 156]) - tokens = self.tokenizer.tokenize("No ▁He") - self.assertEqual(tokens, ["▁No", "", "▁He"]) # spaces are eaten by rstrip / lstrip - - -@require_tiktoken -@require_read_token -class TikTokenIntegrationTests(unittest.TestCase): - """ - A class that regroups important test to make sure that we properly handle the special tokens. - """ - - def test_tiktoken_blt(self): - model_path = "hf-internal-testing/blt-3-8b-internal" - subfolder = "original" - test_text = "This is a test sentence." - test_tokens = [128000, 2028, 374, 264, 1296, 11914, 13, 128001] - num_reserved_special_tokens = 256 - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", - "<|python_tag|>", # end of turn - ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)] - - tiktoken_tokenizer = PreTrainedTokenizerFast.from_pretrained( - model_path, - subfolder=subfolder, - additional_special_tokens=special_tokens, - bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>", - ) - tokens = tiktoken_tokenizer.tokenize("<|begin_of_text|> " + test_text) - self.assertEqual(tokens[0], "<|begin_of_text|>") - - tiktoken_tokenizer = AutoTokenizer.from_pretrained( - model_path, - subfolder=subfolder, - legacy=False, - additional_special_tokens=special_tokens, - bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>", - add_bos_token=True, - add_eos_token=True, - ) - self.assertTrue(isinstance(tiktoken_tokenizer, PreTrainedTokenizerFast)) - - tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True) - self.assertEqual(tokens, test_tokens) - - tmpdirname = tempfile.mkdtemp() - tiktoken_tokenizer.save_pretrained(tmpdirname) - tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname) - - self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast)) - tokens = tokenizer_reload.encode(test_text, add_special_tokens=True) - self.assertEqual(tokens, test_tokens) - shutil.rmtree(tmpdirname) - - tiktoken_tokenizer = AutoTokenizer.from_pretrained( - model_path, - subfolder=subfolder, - additional_special_tokens=special_tokens, - bos_token="<|begin_of_text|>", - eos_token="<|end_of_text|>", - from_slow=True, - add_bos_token=True, - add_eos_token=True, - ) - tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True) - self.assertEqual(tokens, test_tokens) +if __name__ == "__main__": + unittest.main() From 3a22e052293b4c4b3bc25c62b9cc8d284699f3bb Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 8 Jul 2025 12:55:45 +0000 Subject: [PATCH 062/139] tokenizer clean up --- src/transformers/models/blt/tokenization_blt.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 646b78a58120..0669b0299d9e 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -177,10 +177,6 @@ def _tokenize(self, text: str, **kwargs) -> List[str]: """Converts a string to a list of tokens. For BLT, we work directly with byte values.""" return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] - # def decode(self, token_ids, skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None, **kwargs): - # """Converts a sequence of ids in a string, using the tokenizer and vocabulary.""" - # return super().decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs) - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: From d7b572149ef1c46547f5f874488d9611a953ac19 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 8 Jul 2025 12:56:00 +0000 Subject: [PATCH 063/139] modular file --- src/transformers/models/blt/modular_blt.py | 817 ++++++++------------- 1 file changed, 286 insertions(+), 531 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f433e2c8b799..19b3b2a3ae21 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -12,228 +12,146 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BLT model.""" +"""BLT modular model, inheriting from Mllama where appropriate.""" -from ...utils import is_torch_flex_attn_available, logging from typing import Callable, List, Optional, Tuple, Union - -from ...cache_utils import Cache -from ...activations import ACT2FN +from enum import Enum import torch -import torch.distributions -import torch.nn import torch.nn as nn +import torch.nn.functional as F from torch.nn import functional as F +from ...cache_utils import Cache +from ...activations import ACT2FN +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_outputs import CausalLMOutputWithPast +from ...generation.utils import GenerationMixin +from ...utils import logging, is_torch_flex_attn_available from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +# Import configuration classes from .configuration_blt import ( BLTConfig, BLTLocalEncoderConfig, BLTLocalDecoderConfig, BLTGlobalTransformerConfig, BLTPatcherConfig, - PatchingModeEnum, ) if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask -from ..mllama.modeling_mllama import repeat_kv, eager_attention_forward, MllamaRotaryEmbedding, MllamaTextRMSNorm, MllamaCrossAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextSelfAttention +# Import from mllama for inheritance +from ..mllama.modeling_mllama import ( + MllamaTextMLP, + MllamaTextRMSNorm, + MllamaRotaryEmbedding, + MllamaTextCrossAttention, + MllamaSelfAttentionDecoderLayer, + MllamaPreTrainedModel, + eager_attention_forward, + repeat_kv, + apply_rotary_pos_emb as mllama_apply_rotary_pos_emb, +) + +# Import other utility functions and classes from original BLT +from .modeling_blt import ( + PatchingModeEnum, + byte_group_hash_function, + rolling_polynomial_hash, + init_hash_embeddings, + compute_hash_embeddings, + _prepare_patch_cross_attention_mask, + process_patch_lengths, + apply_rotary_pos_emb, +) logger = logging.get_logger(__name__) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] - sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] - - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - - return q_rot.type_as(q), k_rot.type_as(k) +# ============================================================================== +# INHERITED COMPONENTS (minimal changes from Mllama) +# ============================================================================== +class BLTMLP(MllamaTextMLP): + pass -class BLTMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class BLTRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - BLTRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) +class BLTRMSNorm(MllamaTextRMSNorm): + pass - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +class BLTRotaryEmbedding(MllamaRotaryEmbedding): + pass -class BLTTransformerLayer(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) +# ============================================================================== +# INHERITED BUT CUSTOMIZED COMPONENTS +# ============================================================================== - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor`, *optional*): - Position indices of tokens in the sequence for RoPE computation. - past_key_value (`Cache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) +class BLTPreTrainedModel(MllamaPreTrainedModel): + """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False - return outputs + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.config.encoder_config.hidden_size ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, BLTLocalEncoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.hidden_size ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std + + elif isinstance(module, BLTForCausalLM): + if module.lm_head is not None: + module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) class BLTSelfAttention(nn.Module): + """BLT Self Attention that could inherit from Mllama but has some BLT-specific patterns.""" + def __init__(self, config, layer_idx: int): super().__init__() self.config = config @@ -243,7 +161,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None + self.scaling = self.head_dim ** -0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -284,8 +202,6 @@ def forward( attention_interface: Callable = eager_attention_forward output_attentions = False - # self.config._attn_implementation = "sdpa" - # self.scaling = None if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -313,252 +229,39 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - - -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): - primes = [ - 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, - 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime ** powers - return torch.sum(token_tensor * prime_powers, dim=-1) - - -def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): - """Hash token groups and map to range [0, max_hash].""" - with torch.no_grad(): - batch_size, seq_len = token_ids.shape - # Add padding for sliding window - padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - padded_tokens = torch.cat([padding, token_ids], dim=1) - - # Create sliding windows and compute hashes - windows = padded_tokens.unfold(1, group_size, 1) - hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values = hashes % max_hash - - hash_values.requires_grad = False - return hash_values - - -def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): - """Initialize hash-based token embeddings for the BLT encoder.""" - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [ - nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) - for _ in range(num_embeddings) - ] - return nn.ModuleList(embeddings) - - -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """Compute token embeddings enhanced with hash-based embeddings.""" - embeddings = local_encoder.embed_tokens(local_encoder_tokens) - embedding_idx = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab - ) - embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) - embedding_idx += 1 - - return embeddings - - -def _prepare_patch_cross_attention_mask( - patch_ids: torch.Tensor, - num_patches: int, - sequence_length: int, - patches_as_queries: bool = False, - cross_attn_k: int = 1, - dtype: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Prepare cross-attention mask for patch-based attention, following mllama's robust approach. - - This function creates masks that control which patches can attend to which other patches, - with support for query/key role swapping and cross-attention multipliers. - - Args: - patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. - num_patches (int): Total number of patches. - sequence_length (int): Length of the sequence. - patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. - cross_attn_k (int): Cross-attention multiplier for repeating patches. - dtype (torch.dtype): Data type for the output mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows - """ - batch_size, seq_len = patch_ids.shape - device = patch_ids.device - - # Determine query and key lengths based on configuration - if patches_as_queries: - q_len = num_patches * cross_attn_k - kv_len = sequence_length - # Create patch-to-sequence mapping - q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( - batch_size, num_patches, seq_len - ) - kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) - else: - q_len = sequence_length - kv_len = num_patches * cross_attn_k - # Create sequence-to-patch mapping - q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) - kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( - batch_size, seq_len, num_patches - ) - - # Create base attention mask - boolean mask where True means "should attend" - # Exact patch matching - cross_attention_mask = q_patch_ids == kv_patch_ids - - # Handle cross_attn_k multiplier by repeating along appropriate dimension - repeat_dim = 1 if patches_as_queries else -1 - cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) - - # Validate dimensions - expected_shape = (batch_size, q_len, kv_len) - if cross_attention_mask.shape != expected_shape: - raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") - - # Reshape so it can be used by attn module - add head dimension - cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] - - # Invert the mask (following mllama pattern exactly) - # True -> 0.0 (attend), False -> 1.0 (will become -inf) - inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # Apply full-row bias (following mllama pattern exactly) - # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: - """ - Splits patch lengths into smaller segments if they exceed `max_patch_length`. - Pads the result to uniform length across the batch. - Args: - patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. - max_patch_length (int, optional): Maximum allowed length per patch. - Returns: - torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. - """ - if max_patch_length is None: - return patch_lengths - - batch_size = patch_lengths.size(0) - processed = [] - - for seq in patch_lengths: - splits = [] - for length in seq[seq > 0]: - length = length.item() - full_chunks, remainder = divmod(length, max_patch_length) - splits.extend([max_patch_length] * full_chunks) - if remainder: - splits.append(remainder) - processed.append(splits) - - # Find max length to pad to - max_len = max(len(splits) for splits in processed) - padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) +class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): + pass - for i, splits in enumerate(processed): - if splits: - padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - - # Trim zero columns - if (padded != 0).any(dim=0).sum() < padded.shape[1]: - last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 - padded = padded[:, :last_nonzero] - - return padded - - -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config, device=None): - super().__init__() - self.rope_type = config.rope_scaling["rope_type"] - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +# ============================================================================== +# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) +# ============================================================================== class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - self.hidden_size = config.hidden_size - self.vocab_size=config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout - self.cross_attn_all_layers = config.cross_attn_all_layers - self.cross_attn_k = config.cross_attn_k + self.config = config - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( - in_features=config.encoder_dim_patch_emb, - out_features=config.encoder_dim_token_emb * config.cross_attn_k, + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) def forward( @@ -579,23 +282,23 @@ def forward( batch_size, _, _ = input_embeds.shape - hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training) + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for idx, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - if idx == len(self.layers) - 1 or self.cross_attn_all_layers: + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) - layer_idx = idx if self.cross_attn_all_layers else 0 + layer_idx = idx if self.config.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -643,38 +346,33 @@ def __init__(self, config: BLTLocalDecoderConfig): super().__init__() # Extract config values to instance attributes - self.hidden_size = config.hidden_size - self.vocab_size=config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout + self.config = config self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove - self.cross_attn_all_layers = config.cross_attn_all_layers - self.cross_attn_k = config.cross_attn_k - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, - out_features=config.decoder_dim_token_emb * config.cross_attn_k, + out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps) + self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) self.cross_attn_layers = torch.nn.ModuleList() - layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) - self.lm_head = nn.Linear( - self.hidden_size, - self.vocab_size, - bias=False, - ) + # self.lm_head = nn.Linear( + # config.hidden_size, + # config.vocab_size, + # bias=False, + # ) def forward( @@ -692,7 +390,7 @@ def forward( hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds @@ -700,9 +398,9 @@ def forward( position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): - if i == 0 or self.cross_attn_all_layers: + if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, @@ -718,7 +416,8 @@ def forward( layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] - logits = self.lm_head(self.norm(hidden_states)) + logits = self.norm(hidden_states) + # logits = self.lm_head(logits) return logits, cache @@ -730,12 +429,12 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.config = config self.layer_idx = layer_idx # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.hidden_size_local_encoder + self.hidden_size = hidden_size or config.encoder_config.hidden_size self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = None + self.scaling = self.head_dim ** -0.5 self.dropout = config.dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -787,8 +486,6 @@ def forward( attention_interface: Callable = eager_attention_forward - # self.config._attn_implementation = "sdpa" - # attn = "sdpa" if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -808,7 +505,7 @@ def forward( key_states, value_states, attention_mask, - dropout=0.0, #if not self.training else self.dropout, + dropout=0.0, scaling=self.scaling, **kwargs, ) @@ -816,7 +513,6 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - # Apply full row masking if provided (following mllama pattern) if full_text_row_masked_out_mask is not None: attn_output = full_text_row_masked_out_mask[:, 0] * attn_output @@ -832,12 +528,10 @@ class BLTGlobalTransformer(nn.Module): def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.dropout = config.dropout + self.config = config self.layers = nn.ModuleList() - for layer_idx in range(self.num_hidden_layers): + for layer_idx in range(config.num_hidden_layers): self.layers.append(BLTTransformerLayer(config, layer_idx)) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -853,7 +547,7 @@ def forward( hidden_states = input_embeds - hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -865,77 +559,18 @@ def forward( return hidden_states, cache - - -class BLTPreTrainedModel(PreTrainedModel): - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation - _supports_sdpa = True - _supports_cache_class = False - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - std = getattr(module, '_custom_std', module.in_features ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - if module.bias is not None: - nn.init.zeros_(module.bias) - - elif isinstance(module, nn.Embedding): - std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - elif isinstance(module, BLTModel): - if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.hidden_size_local_encoder ** (-0.5) - for emb in module.encoder_hash_tok_embedding: - emb._custom_std = emb_std - - elif isinstance(module, BLTLocalEncoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) - - elif isinstance(module, BLTLocalDecoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5) - - elif isinstance(module, BLTPatcher): - emb_std = module.config.hidden_size ** (-0.5) - module.embed_tokens._custom_std = emb_std - module.lm_head._custom_std = emb_std - - class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) - self.config = config - self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - self.encoder_hash_tok_embedding = init_hash_embeddings( config, - local_encoder_dim=config.hidden_size_local_encoder, + local_encoder_dim=config.encoder_config.hidden_size, encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, ) - if self.config.patch_in_forward: self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() @@ -944,9 +579,30 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs, + ): + """ + Args: + tokens (torch.Tensor): Input token ids. + patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. + Returns: + torch.Tensor: Final hidden states (as before). + """ batch_size, sequence_length = tokens.shape - # Handle patching if patch_lengths is None: if self.config.patching_mode == PatchingModeEnum.entropy: @@ -956,27 +612,23 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, patching_batch_size=self.config.patching_batch_size, - device=self.config.patching_device, + device=tokens.device, ) else: - # Default to byte-level patching patch_lengths = process_patch_lengths( torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), self.config.max_patch_length ) - patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( - patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 - ) - encoder_embeds = compute_hash_embeddings( tokens, self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) - + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype + ) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=tokens, input_embeds=encoder_embeds, @@ -986,18 +638,14 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = num_patches=patch_lengths.shape[1], patch_ids=patch_ids, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, _ = self.global_transformer( input_embeds=global_hidden_states, ) - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, encoder_embeds.dtype ) - output, _ = self.local_decoder( tokens=tokens, embeds=encoder_hidden_states, @@ -1006,7 +654,11 @@ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) - + if output_hidden_states or output_attentions: + if return_dict: + return {"last_hidden_state": output, "hidden_states": None, "attentions": None} + else: + return (output, None, None) return output def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: @@ -1028,12 +680,12 @@ def __init__(self, config: BLTPatcherConfig): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - # Create transformer layers using the patcher config + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) @@ -1078,14 +730,14 @@ def forward( position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl ) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) @@ -1169,6 +821,103 @@ def patch_lengths_from_entropies( return patch_lengths + +class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + + def __init__(self, config): + super().__init__(config) + self.model = BLTModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.local_encoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs, + ): + """ + Args: + input_ids (torch.LongTensor): Input token ids. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. + labels (torch.LongTensor, optional): Labels for language modeling loss. + Returns: + CausalLMOutputWithPast or tuple: Standard transformers output. + """ + # Route only input_ids to BLTModel (as tokens) + hidden_states = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + if isinstance(hidden_states, dict): + sequence_output = hidden_states["last_hidden_state"] + elif isinstance(hidden_states, tuple): + sequence_output = hidden_states[0] + else: + sequence_output = hidden_states + logits = self.lm_head(sequence_output) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + if loss is not None: + output = (loss,) + output + return output + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + __all__ = [ "BLTPreTrainedModel", "BLTModel", @@ -1177,4 +926,10 @@ def patch_lengths_from_entropies( "BLTLocalDecoder", "BLTGlobalTransformer", "BLTTransformerLayer", -] \ No newline at end of file + "BLTForCausalLM", + "BLTMLP", + "BLTRMSNorm", + "BLTRotaryEmbedding", + "BLTSelfAttention", + "BLTCrossAttention", +] \ No newline at end of file From 359ecf11200eda94badaabb410f92d487be63e27 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 8 Jul 2025 13:14:31 +0000 Subject: [PATCH 064/139] fixing rebase --- src/convert_blt_to_hf.py | 397 ------------------ src/transformers/models/__init__.py | 15 +- .../models/auto/tokenization_auto.py | 39 +- .../models/blt/convert_blt_weights_to_hf.py | 7 +- 4 files changed, 53 insertions(+), 405 deletions(-) delete mode 100644 src/convert_blt_to_hf.py diff --git a/src/convert_blt_to_hf.py b/src/convert_blt_to_hf.py deleted file mode 100644 index 26c05477a169..000000000000 --- a/src/convert_blt_to_hf.py +++ /dev/null @@ -1,397 +0,0 @@ -import argparse -import json -import logging -import os -from typing import Any, Dict, Optional - -import torch -from huggingface_hub import hf_hub_download, upload_folder -from safetensors.torch import load_file, save_file - -from transformers.models.blt_wip.configuration_blt import BLTConfig -from transformers.models.blt_wip.modeling_blt import BLTModel -from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM -from transformers.utils import logging as transformers_logging - - -logger = transformers_logging.get_logger(__name__) -transformers_logging.set_verbosity_info() - - -def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: - logger.info("Merging configurations") - - with open(config_path, "r") as f: - main_config = json.load(f) - - with open(entropy_params_path, "r") as f: - entropy_data = json.load(f) - - entropy_model_params = entropy_data.get("entropy_model", {}) - patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) - - unified_config = main_config.copy()["args"] - - for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: - if key in unified_config and not isinstance(unified_config[key], int): - unified_config[key] = int(unified_config[key]) - - patch_size = patcher_args.get("patch_size", 8) - if isinstance(patch_size, float): - patch_size = int(patch_size) - - # Create patcher config - patcher_hidden_size = int(entropy_model_params.get("dim", 512)) - patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256)) - patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of) - - patcher_config = { - "vocab_size": int(entropy_model_params.get("vocab_size", 256)), - "hidden_size": patcher_hidden_size, - "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)), - "num_attention_heads": int(entropy_model_params.get("n_heads", 8)), - "num_key_value_heads": int(entropy_model_params.get("n_kv_heads")) - if entropy_model_params.get("n_kv_heads") is not None - else None, - "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)), - "norm_eps": entropy_model_params.get("norm_eps", 1e-5), - "dropout": entropy_model_params.get("dropout", 0.0), - "rope_theta": entropy_model_params.get("rope_theta", 10000.0), - "attn_impl": entropy_model_params.get("attn_impl", "sdpa"), - "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), - "intermediate_size": patcher_intermediate_size, - } - - # Create encoder config - encoder_hidden_size = unified_config.get("dim_local_encoder", 1024) - encoder_multiple_of = unified_config.get("multiple_of", 256) - encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of) - - encoder_config = { - "vocab_size": unified_config.get("vocab_size", 256), - "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), - "cross_attn_k": unified_config.get("cross_attn_k", 2), - "hidden_size_global": unified_config.get("hidden_size_global", 2048), - "pm_size": unified_config.get("pm_size", 0), - "hidden_size": encoder_hidden_size, - "num_attention_heads": unified_config.get("n_heads_local_encoder", 16), - "num_key_value_heads": unified_config.get("n_kv_heads"), - "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1), - "norm_eps": unified_config.get("norm_eps", 1e-5), - "dropout": unified_config.get("dropout", 0.0), - "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), - "rope_theta": unified_config.get("rope_theta", 10000.0), - "rope_scaling": {"rope_type": "default"}, - "hidden_act": unified_config.get("hidden_act", "silu"), - "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), - "intermediate_size": encoder_intermediate_size, - } - - # Create decoder config - decoder_hidden_size = unified_config.get("dim_local_decoder", 1024) - decoder_multiple_of = unified_config.get("multiple_of", 256) - decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of) - - decoder_config = { - "vocab_size": unified_config.get("vocab_size", 256), - "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), - "cross_attn_k": unified_config.get("cross_attn_k", 2), - "hidden_size_global": unified_config.get("hidden_size_global", 2048), - "hidden_size": decoder_hidden_size, - "num_attention_heads": unified_config.get("n_heads_local_decoder", 16), - "num_key_value_heads": unified_config.get("n_kv_heads"), - "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9), - "norm_eps": unified_config.get("norm_eps", 1e-5), - "dropout": unified_config.get("dropout", 0.0), - "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), - "rope_theta": unified_config.get("rope_theta", 10000.0), - "rope_scaling": {"rope_type": "default"}, - "hidden_act": unified_config.get("hidden_act", "silu"), - "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), - "intermediate_size": decoder_intermediate_size, - } - - # Create global transformer config - global_hidden_size = unified_config.get("dim_global", 2048) - global_multiple_of = unified_config.get("multiple_of", 256) - global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of) - - global_config = { - "hidden_size": global_hidden_size, - "num_attention_heads": unified_config.get("n_heads_global", 16), - "num_key_value_heads": unified_config.get("n_kv_heads_global"), - "num_hidden_layers": unified_config.get("n_layers_global", 25), - "norm_eps": unified_config.get("norm_eps", 1e-5), - "dropout": unified_config.get("dropout", 0.0), - "max_position_embeddings": unified_config.get("max_seqlen", 1024), - "rope_theta": unified_config.get("rope_theta", 10000.0), - "rope_scaling": {"rope_type": "default"}, - "hidden_act": unified_config.get("hidden_act", "silu"), - "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), - "intermediate_size": global_intermediate_size, - } - - # Create main config with sub-configs - main_config_dict = { - "model_type": "blt", - "vocab_size": unified_config.get("vocab_size", 256), - "max_position_embeddings": unified_config.get("max_seqlen", 1024), - "patch_in_forward": True, - "realtime_patching": True, - "patching_mode": "entropy", - "patch_size": patch_size, - "patching_threshold": patcher_args.get("threshold", 0.5), - "patching_threshold_add": patcher_args.get("threshold_add", 0.0), - "max_patch_length": patcher_args.get("max_patch_length"), - "patching_batch_size": patcher_args.get("patching_batch_size", 1), - "patching_device": patcher_args.get("patching_device", "cuda"), - "monotonicity": patcher_args.get("monotonicity", False), - "cross_attn_k": unified_config.get("cross_attn_k", 2), - "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"), - "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000), - "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3), - "pm_size": unified_config.get("pm_size", 0), - "patcher_config": patcher_config, - "encoder_config": encoder_config, - "decoder_config": decoder_config, - "global_config": global_config, - } - - main_config_dict["tie_word_embeddings"] = False - - logger.info(f"Merged configuration with {len(main_config_dict)} parameters") - return main_config_dict - - -def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - component_mappings = { - ".attention.": ".self_attn.", - ".feed_forward.": ".mlp.", - ".attention_norm.": ".input_layernorm.", - ".ffn_norm.": ".post_attention_layernorm.", - ".tok_embeddings.": ".embed_tokens.", - ".cross_attn_norm_q.": ".q_norm.", - ".cross_attn_norm_kv.": ".k_norm.", - ".w1.": ".gate_proj.", - ".w2.": ".down_proj.", - ".w3.": ".up_proj.", - ".wq.": ".q_proj.", - ".wk.": ".k_proj.", - ".wv.": ".v_proj.", - ".wo.": ".o_proj.", - ".output.": ".lm_head.", - } - - new_state_dict = {} - - for old_key, tensor in state_dict.items(): - new_key = old_key - - for old_pattern, new_pattern in component_mappings.items(): - if old_pattern in new_key: - new_key = new_key.replace(old_pattern, new_pattern) - - new_state_dict[new_key] = tensor - - return new_state_dict - - -def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: - main_weights = load_file(weights_path) - - entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) - - if "model" in entropy_weights: - entropy_weights = entropy_weights["model"] - elif "state_dict" in entropy_weights: - entropy_weights = entropy_weights["state_dict"] - - unified_weights = main_weights.copy() - - for key, tensor in entropy_weights.items(): - patcher_key = f"patcher.{key}" - unified_weights[patcher_key] = tensor - - unified_weights = apply_weight_mapping(unified_weights) - - decoder_lm_head_key = "local_decoder.lm_head.weight" - top_lm_head_key = "lm_head.weight" - unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key] - del unified_weights[decoder_lm_head_key] - - prefixed_weights = {} - for key, tensor in unified_weights.items(): - if key == top_lm_head_key: - prefixed_weights[key] = tensor - elif not key.startswith("model."): - prefixed_weights[f"model.{key}"] = tensor - else: - prefixed_weights[key] = tensor - - unified_weights = prefixed_weights - - return unified_weights - - -def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): - tokenizer_config = { - "tokenizer_class": "BltTokenizer", - "vocab_size": config.get("vocab_size", 256), - "model_max_length": config.get("max_seqlen", 1024), - "add_bos_token": True, - "add_eos_token": True, - "bos_token": "", - "eos_token": "", - "pad_token": "", - "unk_token": "", - } - - tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") - with open(tokenizer_path, "w") as f: - json.dump(tokenizer_config, f, indent=2) - - -def push_to_hub( - local_dir: str, - repo_id: str, - commit_message: str = "Upload converted BLT model", - private: bool = False, - token: Optional[str] = None, -) -> None: - try: - upload_folder( - folder_path=local_dir, - repo_id=repo_id, - commit_message=commit_message, - repo_type="model", - token=token, - ) - logger.info(f"Successfully pushed model to {repo_id}") - - except Exception as e: - logger.error(f"Failed to push model to Hub: {e}") - raise - - -def convert_hf_blt_to_unified( - model_id: str, - output_dir: str, - config_name: str = "config.json", - weights_name: str = "model.bin", - cache_dir: Optional[str] = None, - push_to_hub_repo: Optional[str] = None, - hub_private: bool = False, - hub_token: Optional[str] = None, -) -> None: - # Download model files - config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) - weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) - entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) - entropy_weights_path = hf_hub_download( - repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir - ) - - unified_config = merge_configurations(config_path, entropy_params_path) - unified_weights = merge_weights(weights_path, entropy_weights_path) - - os.makedirs(output_dir, exist_ok=True) - - config_path = os.path.join(output_dir, config_name) - with open(config_path, "w") as f: - json.dump(unified_config, f, indent=2) - - if weights_name.endswith(".bin"): - weights_name = weights_name.replace(".bin", ".safetensors") - - weights_path = os.path.join(output_dir, weights_name) - save_file(unified_weights, weights_path) - - create_tokenizer_config(output_dir, unified_config) - - logger.info(f"Conversion completed, model saved to: {output_dir}") - - if push_to_hub_repo: - push_to_hub( - local_dir=output_dir, - repo_id=push_to_hub_repo, - commit_message="Upload BLT model converted", - private=hub_private, - token=hub_token, - ) - - -def main(): - parser = argparse.ArgumentParser( - description="Convert BLT models from HuggingFace Hub format to unified format", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "--model_id", - type=str, - default="facebook/blt-1b", - ) - parser.add_argument( - "--output_dir", - type=str, - default="./blt_converted", - ) - parser.add_argument( - "--config_name", - type=str, - default="config.json", - ) - parser.add_argument( - "--weights_name", - type=str, - default="model.bin", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - ) - parser.add_argument( - "--debug", - action="store_true", - default=True, - ) - parser.add_argument( - "--push_to_hub", - type=str, - default=None, - ) - parser.add_argument( - "--hub_private", - action="store_true", - default=False, - ) - parser.add_argument( - "--hub_token", - type=str, - default="hf_token", - ) - - args = parser.parse_args() - - transformers_logging.set_verbosity_debug() - logging.basicConfig(level=logging.DEBUG) - - try: - convert_hf_blt_to_unified( - model_id=args.model_id, - output_dir=args.output_dir, - config_name=args.config_name, - weights_name=args.weights_name, - cache_dir=args.cache_dir, - push_to_hub_repo=args.push_to_hub, - hub_private=args.hub_private, - hub_token=args.hub_token, - ) - except Exception as e: - logger.error(f"Conversion failed: {e}") - raise - - -if __name__ == "__main__": - main() diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5fbeacd21d78..5cde03fc4b5a 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -22,6 +22,7 @@ from .albert import * from .align import * from .altclip import * + from .arcee import * from .aria import * from .audio_spectrogram_transformer import * from .auto import * @@ -46,8 +47,8 @@ from .blenderbot_small import * from .blip import * from .blip_2 import * - from .blt import * from .bloom import * + from .blt import * from .bridgetower import * from .bros import * from .byt5 import * @@ -65,6 +66,7 @@ from .cohere2 import * from .cohere2_vision import * from .colpali import * + from .colqwen2 import * from .conditional_detr import * from .convbert import * from .convnext import * @@ -92,6 +94,7 @@ from .depth_anything import * from .depth_pro import * from .detr import * + from .dia import * from .dialogpt import * from .diffllama import * from .dinat import * @@ -102,6 +105,7 @@ from .distilbert import * from .dit import * from .donut import * + from .dots1 import * from .dpr import * from .dpt import * from .efficientloftr import * @@ -172,6 +176,7 @@ from .janus import * from .jetmoe import * from .kosmos2 import * + from .kyutai_speech_to_text import * from .layoutlm import * from .layoutlmv2 import * from .layoutlmv3 import * @@ -207,6 +212,7 @@ from .mgp_str import * from .mimi import * from .ministral import * + from .minimax import * from .mistral import * from .mistral3 import * from .mixtral import * @@ -264,6 +270,7 @@ from .plbart import * from .poolformer import * from .pop2piano import * + from .prompt_depth_anything import * from .prophetnet import * from .pvt import * from .pvt_v2 import * @@ -276,6 +283,8 @@ from .qwen3_next import * from .qwen3_vl import * from .qwen3_vl_moe import * + from .qwen3 import * + from .qwen3_moe import * from .rag import * from .recurrent_gemma import * from .reformer import * @@ -300,6 +309,7 @@ from .seggpt import * from .sew import * from .sew_d import * + from .shieldgemma2 import * from .siglip import * from .siglip2 import * from .smolvlm import * @@ -318,6 +328,7 @@ from .swinv2 import * from .switch_transformers import * from .t5 import * + from .t5gemma import * from .table_transformer import * from .tapas import * from .textnet import * @@ -378,4 +389,4 @@ import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 52e32ddd6b77..3ddbfacf0311 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -73,6 +73,7 @@ ), ), ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), @@ -157,6 +158,7 @@ ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), ( "cpm", @@ -193,6 +195,7 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("dia", ("DiaTokenizer", None)), ( "deepseek_vl", ( @@ -275,6 +278,20 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "gemma3n", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3n_text", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), @@ -289,6 +306,10 @@ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)), + ("granite", ("GPT2Tokenizer", None)), + ("granitemoe", ("GPT2Tokenizer", None)), + ("granitemoehybrid", ("GPT2Tokenizer", None)), + ("granitemoeshared", ("GPT2Tokenizer", None)), ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), @@ -394,6 +415,13 @@ ), ), ("mgp-str", ("MgpstrTokenizer", None)), + ( + "minimax", + ( + "GPT2Tokenizer" if is_sentencepiece_available() else None, + "GPT2TokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "mistral", ( @@ -645,6 +673,13 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "t5gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), @@ -900,7 +935,7 @@ class AutoTokenizer: """ def __init__(self): - raise EnvironmentError( + raise OSError( "AutoTokenizer is designed to be instantiated " "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." ) @@ -1178,4 +1213,4 @@ def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok) -__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"] +__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"] \ No newline at end of file diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index d025e09cbc31..aa68e6b09a9c 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -8,9 +8,8 @@ from huggingface_hub import hf_hub_download, upload_folder from safetensors.torch import load_file, save_file -from transformers.models.blt_wip.configuration_blt import BLTConfig -from transformers.models.blt_wip.modeling_blt import BLTModel -from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM +from transformers.models.blt.configuration_blt import BLTConfig +from transformers.models.blt.modeling_blt import BLTModel, BLTForCausalLM from transformers.utils import logging as transformers_logging @@ -394,7 +393,7 @@ def main(): config_name=args.config_name, weights_name=args.weights_name, cache_dir=args.cache_dir, - push_to_hub_repo=args.push_to_hub, + push_to_hub_repo=False, #args.push_to_hub, hub_private=args.hub_private, hub_token=args.hub_token, ) From 4e528374013dd8d7c1af009f9d39ca5523be1bd4 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 8 Jul 2025 13:20:04 +0000 Subject: [PATCH 065/139] ruff --- src/transformers/models/blt/__init__.py | 1 + .../models/blt/configuration_blt.py | 2 - .../models/blt/convert_blt_weights_to_hf.py | 14 +- src/transformers/models/blt/modeling_blt.py | 41 ++-- src/transformers/models/blt/modular_blt.py | 201 ++++++++++-------- .../models/blt/tokenization_blt.py | 18 +- tests/models/blt/test_modeling_blt.py | 6 +- 7 files changed, 146 insertions(+), 137 deletions(-) diff --git a/src/transformers/models/blt/__init__.py b/src/transformers/models/blt/__init__.py index c29d2aa5a8f0..703b81ecdd09 100644 --- a/src/transformers/models/blt/__init__.py +++ b/src/transformers/models/blt/__init__.py @@ -16,6 +16,7 @@ from ...utils import _LazyModule from ...utils.import_utils import define_import_structure + if TYPE_CHECKING: from .configuration_blt import * from .modeling_blt import * diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 15613e307dc5..30555f483b2c 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -14,8 +14,6 @@ # limitations under the License. """BLT model configuration""" -from enum import Enum -from typing import Union from ...configuration_utils import PretrainedConfig from ...utils import logging diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index aa68e6b09a9c..713f0e8ef112 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -2,14 +2,12 @@ import json import logging import os -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from huggingface_hub import hf_hub_download, upload_folder from safetensors.torch import load_file, save_file -from transformers.models.blt.configuration_blt import BLTConfig -from transformers.models.blt.modeling_blt import BLTModel, BLTForCausalLM from transformers.utils import logging as transformers_logging @@ -17,7 +15,7 @@ transformers_logging.set_verbosity_info() -def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: +def merge_configurations(config_path: str, entropy_params_path: str) -> dict[str, Any]: logger.info("Merging configurations") with open(config_path, "r") as f: @@ -172,7 +170,7 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str return main_config_dict -def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def apply_weight_mapping(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: component_mappings = { ".attention.": ".self_attn.", ".feed_forward.": ".mlp.", @@ -205,7 +203,7 @@ def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch return new_state_dict -def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: +def merge_weights(weights_path: str, entropy_weights_path: str) -> dict[str, torch.Tensor]: main_weights = load_file(weights_path) entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) @@ -242,7 +240,7 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor return unified_weights -def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): +def create_tokenizer_config(output_dir: str, config: dict[str, Any]): tokenizer_config = { "tokenizer_class": "BltTokenizer", "vocab_size": config.get("vocab_size", 256), @@ -393,7 +391,7 @@ def main(): config_name=args.config_name, weights_name=args.weights_name, cache_dir=args.cache_dir, - push_to_hub_repo=False, #args.push_to_hub, + push_to_hub_repo=False, # args.push_to_hub, hub_private=args.hub_private, hub_token=args.hub_token, ) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index e089d2cf97a8..d8d485ff8935 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -14,13 +14,8 @@ # limitations under the License. """BLT model.""" -from ...utils import is_torch_flex_attn_available, logging -from typing import Callable, List, Optional, Tuple, Union - from enum import Enum - -from ...cache_utils import Cache -from ...activations import ACT2FN +from typing import Callable, Optional, Union import torch import torch.distributions @@ -28,23 +23,25 @@ import torch.nn as nn from torch.nn import functional as F +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation.utils import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update - from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import is_torch_flex_attn_available, logging from .configuration_blt import ( BLTConfig, - BLTLocalEncoderConfig, - BLTLocalDecoderConfig, BLTGlobalTransformerConfig, + BLTLocalDecoderConfig, + BLTLocalEncoderConfig, BLTPatcherConfig, ) -from ...generation.utils import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) @@ -182,9 +179,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -396,7 +393,7 @@ def _prepare_patch_cross_attention_mask( patches_as_queries: bool = False, cross_attn_k: int = 1, dtype: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Prepare cross-attention mask for patch-based attention, following mllama's robust approach. @@ -584,10 +581,10 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ if input_embeds is None: @@ -700,8 +697,8 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): batch_size, _, _ = embeds.shape @@ -771,12 +768,12 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -860,7 +857,7 @@ def forward( self, input_embeds: torch.Tensor, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): batch_size, seq_len, _ = input_embeds.shape diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 19b3b2a3ae21..6219b392b791 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,60 +14,53 @@ # limitations under the License. """BLT modular model, inheriting from Mllama where appropriate.""" -from typing import Callable, List, Optional, Tuple, Union -from enum import Enum +from typing import Callable, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn import functional as F from ...cache_utils import Cache -from ...activations import ACT2FN -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_outputs import CausalLMOutputWithPast from ...generation.utils import GenerationMixin -from ...utils import logging, is_torch_flex_attn_available -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import is_torch_flex_attn_available, logging # Import configuration classes from .configuration_blt import ( BLTConfig, - BLTLocalEncoderConfig, - BLTLocalDecoderConfig, BLTGlobalTransformerConfig, + BLTLocalDecoderConfig, + BLTLocalEncoderConfig, BLTPatcherConfig, ) + if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask + # Import from mllama for inheritance from ..mllama.modeling_mllama import ( - MllamaTextMLP, - MllamaTextRMSNorm, + MllamaPreTrainedModel, MllamaRotaryEmbedding, - MllamaTextCrossAttention, MllamaSelfAttentionDecoderLayer, - MllamaPreTrainedModel, + MllamaTextMLP, + MllamaTextRMSNorm, eager_attention_forward, - repeat_kv, - apply_rotary_pos_emb as mllama_apply_rotary_pos_emb, ) # Import other utility functions and classes from original BLT from .modeling_blt import ( PatchingModeEnum, - byte_group_hash_function, - rolling_polynomial_hash, - init_hash_embeddings, - compute_hash_embeddings, _prepare_patch_cross_attention_mask, - process_patch_lengths, apply_rotary_pos_emb, + compute_hash_embeddings, + init_hash_embeddings, + process_patch_lengths, ) + logger = logging.get_logger(__name__) @@ -75,6 +68,7 @@ # INHERITED COMPONENTS (minimal changes from Mllama) # ============================================================================== + class BLTMLP(MllamaTextMLP): pass @@ -91,8 +85,10 @@ class BLTRotaryEmbedding(MllamaRotaryEmbedding): # INHERITED BUT CUSTOMIZED COMPONENTS # ============================================================================== + class BLTPreTrainedModel(MllamaPreTrainedModel): """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" + config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -104,7 +100,7 @@ class BLTPreTrainedModel(MllamaPreTrainedModel): def _init_weights(self, module): if isinstance(module, nn.Linear): - std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + std = getattr(module, "_custom_std", module.in_features ** (-0.5)) nn.init.trunc_normal_( module.weight, mean=0.0, @@ -114,9 +110,9 @@ def _init_weights(self, module): ) if module.bias is not None: nn.init.zeros_(module.bias) - + elif isinstance(module, nn.Embedding): - std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + std = getattr(module, "_custom_std", module.embedding_dim ** (-0.5)) nn.init.trunc_normal_( module.weight, mean=0.0, @@ -124,26 +120,26 @@ def _init_weights(self, module): a=-3 * std, b=3 * std, ) - + elif isinstance(module, BLTModel): if module.encoder_hash_tok_embedding is not None: emb_std = module.config.encoder_config.hidden_size ** (-0.5) for emb in module.encoder_hash_tok_embedding: emb._custom_std = emb_std - + elif isinstance(module, BLTLocalEncoder): if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - + elif isinstance(module, BLTLocalDecoder): if module.patch_embedding_projection is not None: module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - + elif isinstance(module, BLTPatcher): emb_std = module.config.hidden_size ** (-0.5) module.embed_tokens._custom_std = emb_std module.lm_head._custom_std = emb_std - + elif isinstance(module, BLTForCausalLM): if module.lm_head is not None: module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) @@ -151,7 +147,7 @@ def _init_weights(self, module): class BLTSelfAttention(nn.Module): """BLT Self Attention that could inherit from Mllama but has some BLT-specific patterns.""" - + def __init__(self, config, layer_idx: int): super().__init__() self.config = config @@ -161,7 +157,7 @@ def __init__(self, config, layer_idx: int): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx @@ -171,30 +167,30 @@ def __init__(self, config, layer_idx: int): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -236,16 +232,19 @@ class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): # ============================================================================== -# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) +# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) # ============================================================================== + class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - + self.config = config - - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.layers = nn.ModuleList( + [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -271,10 +270,10 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ if input_embeds is None: @@ -282,10 +281,10 @@ def forward( batch_size, _, _ = input_embeds.shape - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -296,7 +295,9 @@ def forward( if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) layer_idx = idx if self.config.cross_attn_all_layers else 0 cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( @@ -312,7 +313,7 @@ def forward( encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states - + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch @@ -328,7 +329,9 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) - reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = torch.zeros( + (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) reduced_embeddings = reduced_embeddings.scatter_reduce( src=hidden_states, dim=1, @@ -347,9 +350,11 @@ def __init__(self, config: BLTLocalDecoderConfig): # Extract config values to instance attributes self.config = config - self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove - self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.rotary_emb = BLTRotaryEmbedding(config=config) @@ -374,7 +379,6 @@ def __init__(self, config: BLTLocalDecoderConfig): # bias=False, # ) - def forward( self, tokens: torch.Tensor, @@ -382,21 +386,23 @@ def forward( patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): batch_size, _, _ = embeds.shape hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) - patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + patch_embeds = patch_embeds.reshape( + batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size + ) if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): @@ -417,7 +423,7 @@ def forward( hidden_states = layer_outputs[0] logits = self.norm(hidden_states) - # logits = self.lm_head(logits) + # logits = self.lm_head(logits) return logits, cache @@ -434,7 +440,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.dropout = config.dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -451,16 +457,16 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first + + query_states = self.q_norm(hidden_states) # BLT normalizes first query_states = self.q_proj(query_states) if cross_attention_states is not None: @@ -536,12 +542,11 @@ def __init__(self, config: BLTGlobalTransformerConfig): self.rotary_emb = BLTRotaryEmbedding(config=config) - def forward( self, input_embeds: torch.Tensor, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, ): batch_size, seq_len, _ = input_embeds.shape @@ -550,7 +555,7 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) @@ -617,11 +622,13 @@ def forward( else: patch_lengths = process_patch_lengths( torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), - self.config.max_patch_length + self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) encoder_embeds = compute_hash_embeddings( - tokens, self.local_encoder, self.encoder_hash_tok_embedding, + tokens, + self.local_encoder, + self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, @@ -644,7 +651,12 @@ def forward( ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( - decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, encoder_embeds.dtype + decoder_patch_ids, + patch_lengths.shape[1], + sequence_length, + False, + self.config.cross_attn_k, + encoder_embeds.dtype, ) output, _ = self.local_decoder( tokens=tokens, @@ -660,18 +672,21 @@ def forward( else: return (output, None, None) return output - + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """Convert patch lengths to patch IDs for each token position.""" batch_size = patch_lengths.shape[0] - patch_starts = torch.cat([ - torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), - patch_lengths.cumsum(dim=-1)[:, :-1] - ], dim=-1) - + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], + ], + dim=-1, + ) + token_positions = torch.arange(seq_len, device=patch_lengths.device) return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 - + class BLTPatcher(BLTPreTrainedModel): def __init__(self, config: BLTPatcherConfig): @@ -680,11 +695,10 @@ def __init__(self, config: BLTPatcherConfig): self.rotary_emb = BLTRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() - + for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) @@ -704,7 +718,6 @@ def forward( patching_batch_size: int = 1, device: Optional[str] = None, ): - # Handle chunked processing for entropy calculation entropies = [] predictions = [] @@ -729,15 +742,15 @@ def forward( batch_size, _, _ = input_embeds.shape position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - - position_embeddings = self.rotary_emb(hidden_states, position_ids) - + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for i, layer in enumerate(self.layers): layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) @@ -758,7 +771,9 @@ def forward( ) else: # Default to byte-level patching - patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device + ) patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) return concat_entropies, patch_lengths, concat_predictions @@ -780,7 +795,9 @@ def patch_lengths_from_entropies( batch_size = entropies.shape[0] # Always include token 0 and 1 as starting tokens - init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) offset = init_tokens.shape[1] # Ignore first token entropy (BOS) @@ -923,7 +940,7 @@ def forward( "BLTModel", "BLTPatcher", "BLTLocalEncoder", - "BLTLocalDecoder", + "BLTLocalDecoder", "BLTGlobalTransformer", "BLTTransformerLayer", "BLTForCausalLM", @@ -932,4 +949,4 @@ def forward( "BLTRotaryEmbedding", "BLTSelfAttention", "BLTCrossAttention", -] \ No newline at end of file +] diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 0669b0299d9e..1c6c39544a59 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -14,14 +14,14 @@ # limitations under the License. """Tokenization classes for BLT.""" -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Optional from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging if TYPE_CHECKING: - from ...tokenization_utils_base import TextInput + pass logger = logging.get_logger(__name__) @@ -150,7 +150,7 @@ def _convert_id_to_token(self, index: int) -> str: return self.decoder.get(index, str(self.unk_token)) - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: """Converts a sequence of tokens to a single string.""" byte_values = [] @@ -173,13 +173,13 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: return bytes(byte_values).decode("utf-8", errors="ignore") - def _tokenize(self, text: str, **kwargs) -> List[str]: + def _tokenize(self, text: str, **kwargs) -> list[str]: """Converts a string to a list of tokens. For BLT, we work directly with byte values.""" return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: """ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and adding special tokens. A BLT sequence has the following format: @@ -204,8 +204,8 @@ def build_inputs_with_special_tokens( return bos + token_ids_0 + eos + token_ids_1 + eos def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -237,7 +237,7 @@ def get_vocab_size(self) -> int: """Get vocab size like the original tokenizer.""" return self.vocab_size - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: # BLT doesn't require external vocabulary files since it uses byte-level tokenization return () diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 5910d46eabbe..0e673f2822bc 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -15,9 +15,7 @@ import unittest -from packaging import version - -from transformers import AutoTokenizer, StaticCache, is_torch_available +from transformers import AutoTokenizer, is_torch_available from transformers.testing_utils import ( cleanup, require_read_token, @@ -34,7 +32,7 @@ if is_torch_available(): import torch - from transformers import BLTConfig, BLTForCausalLM, BLTModel, BLTTokenizer + from transformers import BLTConfig, BLTForCausalLM, BLTModel from transformers.models.blt.modeling_blt import BLTRotaryEmbedding From fed995884b3df22d5368acf725c59eb2cfa964a7 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 9 Jul 2025 10:25:23 +0000 Subject: [PATCH 066/139] adding correct basemodel output and updating config with checkpoint vals (for testing) --- .../models/blt/configuration_blt.py | 81 +++++----- src/transformers/models/blt/modeling_blt.py | 140 +++++++++++------- 2 files changed, 130 insertions(+), 91 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 30555f483b2c..c00ee31b7a7e 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -31,21 +31,22 @@ class BLTLocalEncoderConfig(PretrainedConfig): def __init__( self, - vocab_size=256, - cross_attn_all_layers=True, + vocab_size=260, + cross_attn_all_layers=False, cross_attn_k=2, hidden_size_global=2048, - hidden_size=512, - num_attention_heads=8, + pm_size=0, + hidden_size=1024, + num_attention_heads=16, num_key_value_heads=None, - num_hidden_layers=8, + num_hidden_layers=1, norm_eps=1e-5, dropout=0.0, - max_position_embeddings=1024, - rope_theta=10000.0, + max_position_embeddings=24576, + rope_theta=500000.0, rope_scaling=None, hidden_act="silu", - intermediate_size=None, + intermediate_size=2816, _attn_implementation="sdpa", **kwargs, ): @@ -53,6 +54,7 @@ def __init__( self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k self.hidden_size_global = hidden_size_global + self.pm_size = pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -80,21 +82,21 @@ class BLTLocalDecoderConfig(PretrainedConfig): def __init__( self, - vocab_size=256, + vocab_size=260, cross_attn_all_layers=True, cross_attn_k=2, hidden_size_global=2048, - hidden_size=512, - num_attention_heads=8, + hidden_size=1024, + num_attention_heads=16, num_key_value_heads=None, - num_hidden_layers=8, + num_hidden_layers=9, norm_eps=1e-5, dropout=0.0, - max_position_embeddings=1024, - rope_theta=10000.0, + max_position_embeddings=24576, + rope_theta=500000.0, rope_scaling=None, hidden_act="silu", - intermediate_size=None, + intermediate_size=2816, _attn_implementation="sdpa", **kwargs, ): @@ -130,17 +132,17 @@ class BLTGlobalTransformerConfig(PretrainedConfig): def __init__( self, - hidden_size=512, - num_attention_heads=8, + hidden_size=2048, + num_attention_heads=16, num_key_value_heads=None, - num_hidden_layers=8, + num_hidden_layers=25, norm_eps=1e-5, dropout=0.0, - max_position_embeddings=1024, - rope_theta=10000.0, + max_position_embeddings=4096, + rope_theta=500000.0, rope_scaling=None, hidden_act="silu", - intermediate_size=None, + intermediate_size=5632, _attn_implementation="sdpa", **kwargs, ): @@ -201,18 +203,18 @@ class BLTPatcherConfig(PretrainedConfig): def __init__( self, - vocab_size=256, - hidden_size=512, - num_hidden_layers=8, - num_attention_heads=8, + vocab_size=260, + hidden_size=768, + num_hidden_layers=14, + num_attention_heads=12, num_key_value_heads=None, - max_position_embeddings=1024, + max_position_embeddings=8192, norm_eps=1e-5, dropout=0.0, rope_theta=10000.0, _attn_implementation="sdpa", - attn_bias_type="causal", - intermediate_size=None, + attn_bias_type="local_block_causal", + intermediate_size=2048, **kwargs, ): self.vocab_size = vocab_size @@ -313,18 +315,18 @@ class BLTConfig(PretrainedConfig): def __init__( self, - vocab_size=256, - max_position_embeddings=1024, - patch_in_forward=False, - patch_size=None, - patching_mode=None, - patching_threshold=None, + vocab_size=260, + max_position_embeddings=4096, + patch_in_forward=True, + patch_size=4, + patching_mode="entropy", + patching_threshold=1.335442066192627, patching_batch_size=1, max_patch_length=None, cross_attn_k=2, encoder_hash_byte_group_size=None, - encoder_hash_byte_group_vocab=30000, - encoder_hash_byte_group_nb_functions=3, + encoder_hash_byte_group_vocab=500002, + encoder_hash_byte_group_nb_functions=1, patcher_config=None, encoder_config=None, decoder_config=None, @@ -346,12 +348,17 @@ def __init__( self.patching_threshold = patching_threshold self.patching_batch_size = patching_batch_size self.max_patch_length = max_patch_length + self.patching_device = kwargs.get("patching_device", "cuda") + self.realtime_patching = kwargs.get("realtime_patching", True) + self.patching_threshold_add = kwargs.get("patching_threshold_add", None) + self.monotonicity = kwargs.get("monotonicity", False) + self.pm_size = kwargs.get("pm_size", 0) # Cross attention configurations self.cross_attn_k = cross_attn_k # Encoder configurations - self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [2, 3, 4] + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [3, 4, 5, 6, 7, 8] self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index d8d485ff8935..ec028570bb2f 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation.utils import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import is_torch_flex_attn_available, logging @@ -43,7 +43,6 @@ from torch.nn.attention.flex_attention import BlockMask - logger = logging.get_logger(__name__) @@ -260,12 +259,13 @@ def __init__(self, config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - past_key_value=None, - cache_position=None, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -278,17 +278,20 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if position_embeddings is not None + else {"cache_position": cache_position} + ) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - output_attentions = False if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -576,15 +579,23 @@ def __init__(self, config: BLTLocalEncoderConfig): def forward( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, input_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): """ """ if input_embeds is None: @@ -692,13 +703,21 @@ def __init__(self, config: BLTLocalDecoderConfig): def forward( self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, + embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): batch_size, _, _ = embeds.shape @@ -712,7 +731,7 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = torch.arange(input_ids.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -772,6 +791,7 @@ def forward( output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -856,8 +876,16 @@ def __init__(self, config: BLTGlobalTransformerConfig): def forward( self, input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): batch_size, seq_len, _ = input_embeds.shape @@ -954,47 +982,47 @@ def __init__(self, config: BLTConfig): def forward( self, - tokens: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, patch_lengths: Optional[torch.Tensor] = None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ): + ) -> Union[BaseModelOutputWithPast, tuple]: """ Args: - tokens (torch.Tensor): Input token ids. + input_ids (torch.LongTensor, optional): Input token ids. patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. Returns: - torch.Tensor: Final hidden states (as before). + Union[BaseModelOutputWithPast, tuple]: Model outputs. """ - batch_size, sequence_length = tokens.shape + batch_size, sequence_length = input_ids.shape # Handle patching if patch_lengths is None: if self.config.patching_mode == PatchingModeEnum.entropy: _, patch_lengths, _ = self.patcher( - tokens, + input_ids, patch_size=self.config.patch_size, threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, patching_batch_size=self.config.patching_batch_size, - device=tokens.device, + device=input_ids.device, ) else: patch_lengths = process_patch_lengths( - torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + torch.ones((batch_size, sequence_length + 1), dtype=input_ids.dtype, device=input_ids.device), self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) encoder_embeds = compute_hash_embeddings( - tokens, + input_ids, self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, @@ -1005,7 +1033,7 @@ def forward( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) encoder_hidden_states, encoder_cross_states = self.local_encoder( - input_ids=tokens, + input_ids=input_ids, input_embeds=encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, @@ -1027,19 +1055,23 @@ def forward( encoder_embeds.dtype, ) output, _ = self.local_decoder( - tokens=tokens, + input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, ) - if output_hidden_states or output_attentions: - if return_dict: - return {"last_hidden_state": output, "hidden_states": None, "attentions": None} - else: - return (output, None, None) - return output + + if not return_dict: + return (output, past_key_values) if use_cache else (output,) + + return BaseModelOutputWithPast( + last_hidden_state=output, + past_key_values=past_key_values if use_cache else None, + hidden_states=None, # BLT doesn't currently support output_hidden_states + attentions=None, # BLT doesn't currently support output_attentions + ) def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """Convert patch lengths to patch IDs for each token position.""" @@ -1240,26 +1272,26 @@ def get_decoder(self): def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - cache_position=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ): + ) -> Union[CausalLMOutputWithPast, tuple]: """ Args: input_ids (torch.LongTensor): Input token ids. attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. labels (torch.LongTensor, optional): Labels for language modeling loss. Returns: - CausalLMOutputWithPast or tuple: Standard transformers output. + Union[CausalLMOutputWithPast, tuple]: Standard transformers output. """ # Route only input_ids to BLTModel (as tokens) hidden_states = self.model( From 3ab48a664bf308c529eca16d88be535de834997a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 10 Jul 2025 08:35:29 +0000 Subject: [PATCH 067/139] BLTModelTests git status --- src/transformers/models/blt/modeling_blt.py | 139 ++++++++++++++++---- tests/models/blt/test_modeling_blt.py | 112 ++++++++++++++++ 2 files changed, 229 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index ec028570bb2f..b922282925c5 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -598,6 +598,10 @@ def forward( **kwargs, ): """ """ + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + if input_embeds is None: input_embeds = self.embed_tokens(input_ids) @@ -611,9 +615,15 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for idx, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) @@ -634,7 +644,7 @@ def forward( patch_embeds = patch_embeds + cross_attention_output encoder_cross_states = patch_embeds - return hidden_states, encoder_cross_states + return hidden_states, encoder_cross_states, all_hidden_states, all_attentions def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ @@ -719,6 +729,10 @@ def forward( cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, **kwargs, ): + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + batch_size, _, _ = embeds.shape hidden_states = embeds @@ -736,6 +750,9 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( @@ -749,12 +766,15 @@ def forward( ) hidden_states = hidden_states + cross_attention_output - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + logits = self.norm(hidden_states) # logits = self.lm_head(logits) - return logits, cache + return logits, all_hidden_states, all_attentions class BLTCrossAttention(nn.Module): @@ -887,6 +907,10 @@ def forward( cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, **kwargs, ): + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + batch_size, seq_len, _ = input_embeds.shape hidden_states = input_embeds @@ -897,10 +921,16 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] - return hidden_states, cache + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + return hidden_states, all_hidden_states, all_attentions class BLTPreTrainedModel(PreTrainedModel): @@ -1003,10 +1033,19 @@ def forward( Returns: Union[BaseModelOutputWithPast, tuple]: Model outputs. """ + # Set defaults from config when parameters are None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Initialize collections to None - we will ONLY collect from decoder + all_hidden_states = None + all_attentions = None + batch_size, sequence_length = input_ids.shape # Handle patching if patch_lengths is None: - if self.config.patching_mode == PatchingModeEnum.entropy: + if self.config.patching_mode == PatchingModeEnum.entropy and self.patcher is not None: _, patch_lengths, _ = self.patcher( input_ids, patch_size=self.config.patch_size, @@ -1032,7 +1071,7 @@ def forward( cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states = self.local_encoder( + encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = self.local_encoder( input_ids=input_ids, input_embeds=encoder_embeds, patch_embeds=None, @@ -1040,11 +1079,17 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, + output_attentions=False, # Don't collect encoder attentions + output_hidden_states=False, # Don't collect encoder hidden states ) + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, _ = self.global_transformer( + global_hidden_states, global_hidden_states_all, global_attentions_all = self.global_transformer( input_embeds=global_hidden_states, + output_attentions=False, # Don't collect global transformer attentions + output_hidden_states=False, # Don't collect global transformer hidden states ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, @@ -1054,23 +1099,43 @@ def forward( self.config.cross_attn_k, encoder_embeds.dtype, ) - output, _ = self.local_decoder( + output, decoder_hidden_states_all, decoder_attentions_all = self.local_decoder( input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + output_attentions=output_attentions, # Only collect decoder attentions + output_hidden_states=output_hidden_states, # Only collect decoder hidden states ) + # Only use decoder outputs (which match the expected num_hidden_layers) + if output_hidden_states and decoder_hidden_states_all is not None: + all_hidden_states = decoder_hidden_states_all + else: + all_hidden_states = None + + if output_attentions and decoder_attentions_all is not None: + all_attentions = decoder_attentions_all + else: + all_attentions = None + if not return_dict: - return (output, past_key_values) if use_cache else (output,) + output = (output,) + if past_key_values is not None: + output = output + (past_key_values,) + if all_hidden_states is not None: + output = output + (all_hidden_states,) + if all_attentions is not None: + output = output + (all_attentions,) + return output return BaseModelOutputWithPast( last_hidden_state=output, past_key_values=past_key_values if use_cache else None, - hidden_states=None, # BLT doesn't currently support output_hidden_states - attentions=None, # BLT doesn't currently support output_attentions + hidden_states=all_hidden_states, + attentions=all_attentions, ) def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: @@ -1293,8 +1358,13 @@ def forward( Returns: Union[CausalLMOutputWithPast, tuple]: Standard transformers output. """ + # Set defaults from config when parameters are None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.return_dict + # Route only input_ids to BLTModel (as tokens) - hidden_states = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1307,12 +1377,37 @@ def forward( cache_position=cache_position, **kwargs, ) - if isinstance(hidden_states, dict): - sequence_output = hidden_states["last_hidden_state"] - elif isinstance(hidden_states, tuple): - sequence_output = hidden_states[0] + + if isinstance(outputs, dict): + sequence_output = outputs["last_hidden_state"] + past_key_values = outputs.get("past_key_values") + hidden_states = outputs.get("hidden_states") + attentions = outputs.get("attentions") + elif isinstance(outputs, tuple): + sequence_output = outputs[0] + # Handle tuple format: (output, past_key_values?, hidden_states?, attentions?) + idx = 1 + past_key_values = None + hidden_states = None + attentions = None + + if len(outputs) > idx and use_cache: + past_key_values = outputs[idx] + idx += 1 + + if len(outputs) > idx and output_hidden_states: + hidden_states = outputs[idx] + idx += 1 + + if len(outputs) > idx and output_attentions: + attentions = outputs[idx] + idx += 1 else: - sequence_output = hidden_states + sequence_output = outputs + past_key_values = None + hidden_states = None + attentions = None + logits = self.lm_head(sequence_output) loss = None if labels is not None: @@ -1329,9 +1424,9 @@ def forward( return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, + past_key_values=past_key_values, + hidden_states=hidden_states, + attentions=attentions, ) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 0e673f2822bc..006cf353901d 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -42,6 +42,118 @@ class BLTModelTester(CausalLMModelTester): base_model_class = BLTModel causal_lm_class = BLTForCausalLM + def __init__( + self, + parent, + ignore_index=-100, + seq_length=7, + is_training=True, + ): + super().__init__(parent) + self.parent = parent + self.ignore_index = ignore_index + self.seq_length = seq_length + self.is_training = is_training + self.batch_size = 3 + + # Common parameters for all configs + self.hidden_size = 32 + self.num_hidden_layers = 2 + self.num_attention_heads = 4 + self.num_key_value_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "silu" + self.max_position_embeddings = 512 + self.vocab_size = 99 + self.rope_theta = 500000.0 + self.rope_scaling = {"rope_type": "default"} + self.norm_eps = 1e-5 + self.dropout = 0.0 + + self.patcher_config = { + "hidden_size": self.hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_size": self.intermediate_size, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "_attn_implementation": "eager" + } + + self.encoder_config = { + "hidden_size": self.hidden_size, + "num_hidden_layers": self.num_hidden_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_size": self.intermediate_size, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "_attn_implementation": "eager" + } + + self.decoder_config = { + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "hidden_size_global": self.hidden_size * 2, # Must match global transformer output size + "num_hidden_layers": self.num_hidden_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_size": self.intermediate_size, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "_attn_implementation": "eager" + } + + self.global_config = { + "hidden_size": self.hidden_size * 2, # Double the hidden size for global transformer + "num_hidden_layers": self.num_hidden_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": self.num_key_value_heads, + "intermediate_size": self.intermediate_size, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "rope_scaling": self.rope_scaling, + "hidden_act": self.hidden_act, + "norm_eps": self.norm_eps, + "dropout": self.dropout, + "_attn_implementation": "eager" + } + + def get_config(self): + return BLTConfig( + vocab_size=self.vocab_size, + max_position_embeddings=self.max_position_embeddings, + patch_in_forward=False, # Disable patching for tests + patch_size=4, + patching_mode="entropy", + patching_threshold=1.335442066192627, + patching_batch_size=1, + max_patch_length=None, + cross_attn_k=2, + encoder_hash_byte_group_size=[3, 4, 5, 6, 7, 8], + encoder_hash_byte_group_vocab=500002, + encoder_hash_byte_group_nb_functions=1, + patcher_config=self.patcher_config, + encoder_config=self.encoder_config, + decoder_config=self.decoder_config, + global_config=self.global_config, + tie_word_embeddings=False, + _attn_implementation="eager" + ) + @require_torch class BLTModelTest(CausalLMModelTest, unittest.TestCase): From 6f9319928533f276f36f3621da5c8b2cee861688 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 10 Jul 2025 15:29:25 +0000 Subject: [PATCH 068/139] enabling inputs_embeds, although won't be equal to input_ids since need ids for patching logic --- .../models/blt/configuration_blt.py | 20 ++- src/transformers/models/blt/modeling_blt.py | 119 +++++++++--------- tests/models/blt/test_modeling_blt.py | 65 +++++++++- 3 files changed, 137 insertions(+), 67 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index c00ee31b7a7e..84d3ee302259 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -65,7 +65,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.rope_scaling = rope_scaling or {"type": "default"} self.hidden_act = hidden_act super().__init__(**kwargs) @@ -114,7 +114,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.rope_scaling = rope_scaling or {"type": "default"} self.hidden_act = hidden_act self._attn_implementation = _attn_implementation @@ -156,7 +156,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.rope_scaling = rope_scaling or {"type": "default"} self.hidden_act = hidden_act super().__init__(**kwargs) @@ -230,7 +230,7 @@ def __init__( self.attn_bias_type = attn_bias_type self.hidden_act = "silu" # BLT uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) - self.rope_scaling = {"rope_type": "default"} + self.rope_scaling = {"type": "default"} super().__init__(**kwargs) self._attn_implementation = _attn_implementation @@ -333,6 +333,9 @@ def __init__( global_config=None, tie_word_embeddings=False, _attn_implementation="sdpa", + initializer_range=0.02, + rope_theta=500000.0, + rope_scaling=None, **kwargs, ): # Basic model configuration @@ -340,6 +343,9 @@ def __init__( self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self._attn_implementation = _attn_implementation + self.initializer_range = initializer_range + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"type": "default"} # Patching configuration self.patch_in_forward = patch_in_forward @@ -397,6 +403,12 @@ def __init__( super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Add decoder config attributes to main config for compatibility with tests + # These mirror the decoder config attributes since the main model interface uses the decoder + # self.hidden_size = self.decoder_config.hidden_size + # self.num_hidden_layers = self.decoder_config.num_hidden_layers + # self.num_attention_heads = self.decoder_config.num_attention_heads + __all__ = [ "BLTConfig", diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index b922282925c5..fca2ca79e8eb 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -524,7 +524,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona class BLTRotaryEmbedding(nn.Module): def __init__(self, config, device=None): super().__init__() - self.rope_type = config.rope_scaling["rope_type"] + self.rope_type = config.rope_scaling.get("type", "default") self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -609,7 +609,7 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -705,12 +705,6 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) - # self.lm_head = nn.Linear( - # config.hidden_size, - # config.vocab_size, - # bias=False, - # ) - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -745,7 +739,9 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - position_ids = torch.arange(input_ids.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + # Use sequence length from embeds (standard transformers pattern) + seq_len = embeds.shape[1] + position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -772,6 +768,10 @@ def forward( if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + # Add final hidden state after all layers + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + logits = self.norm(hidden_states) # logits = self.lm_head(logits) return logits, all_hidden_states, all_attentions @@ -944,50 +944,23 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_cache_class = False def _init_weights(self, module): - if isinstance(module, nn.Linear): - std = getattr(module, "_custom_std", module.in_features ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - if module.bias is not None: - nn.init.zeros_(module.bias) + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Embedding): - std = getattr(module, "_custom_std", module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - - elif isinstance(module, BLTModel): - if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.encoder_config.hidden_size ** (-0.5) - for emb in module.encoder_hash_tok_embedding: - emb._custom_std = emb_std - - elif isinstance(module, BLTLocalEncoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - - elif isinstance(module, BLTLocalDecoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) - - elif isinstance(module, BLTPatcher): - emb_std = module.config.hidden_size ** (-0.5) - module.embed_tokens._custom_std = emb_std - module.lm_head._custom_std = emb_std - - elif isinstance(module, BLTForCausalLM): - if module.lm_head is not None: - module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, BLTRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.RMSNorm): + module.weight.data.fill_(1.0) class BLTModel(BLTPreTrainedModel): @@ -1010,6 +983,9 @@ def __init__(self, config: BLTConfig): else: self.patcher = None + # Initialize weights and apply final processing + self.post_init() + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1042,10 +1018,18 @@ def forward( all_hidden_states = None all_attentions = None - batch_size, sequence_length = input_ids.shape + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape + else: + batch_size, sequence_length, _ = inputs_embeds.shape # Handle patching if patch_lengths is None: if self.config.patching_mode == PatchingModeEnum.entropy and self.patcher is not None: + if input_ids is None: + raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( input_ids, patch_size=self.config.patch_size, @@ -1055,19 +1039,24 @@ def forward( device=input_ids.device, ) else: + device = input_ids.device if input_ids is not None else inputs_embeds.device + dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype patch_lengths = process_patch_lengths( - torch.ones((batch_size, sequence_length + 1), dtype=input_ids.dtype, device=input_ids.device), + torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device), self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - encoder_embeds = compute_hash_embeddings( - input_ids, - self.local_encoder, - self.encoder_hash_tok_embedding, - self.config.encoder_hash_byte_group_nb_functions, - self.config.encoder_hash_byte_group_size, - self.config.encoder_hash_byte_group_vocab, - ) + if inputs_embeds is not None: + encoder_embeds = inputs_embeds + else: + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) @@ -1138,6 +1127,14 @@ def forward( attentions=all_attentions, ) + def get_input_embeddings(self): + """Returns the model's input embeddings.""" + return self.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + """Sets the model's input embeddings.""" + self.local_encoder.embed_tokens = value + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """Convert patch lengths to patch IDs for each token position.""" batch_size = patch_lengths.shape[0] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 006cf353901d..6aa9eaec1a40 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -15,7 +15,9 @@ import unittest -from transformers import AutoTokenizer, is_torch_available +import pytest +from parameterized import parameterized +from transformers import AutoTokenizer, is_torch_available, set_seed from transformers.testing_utils import ( cleanup, require_read_token, @@ -27,6 +29,7 @@ ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...test_modeling_common import ids_tensor if is_torch_available(): @@ -133,7 +136,7 @@ def __init__( } def get_config(self): - return BLTConfig( + config = BLTConfig( vocab_size=self.vocab_size, max_position_embeddings=self.max_position_embeddings, patch_in_forward=False, # Disable patching for tests @@ -154,6 +157,12 @@ def get_config(self): _attn_implementation="eager" ) + config.num_attention_heads = config.decoder_config.num_attention_heads + config.num_hidden_layers = config.decoder_config.num_hidden_layers + config.hidden_size = config.decoder_config.hidden_size + + return config + @require_torch class BLTModelTest(CausalLMModelTest, unittest.TestCase): @@ -186,6 +195,58 @@ class BLTModelTest(CausalLMModelTest, unittest.TestCase): # used in `test_torch_compile_for_training` _torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None + @pytest.mark.generate + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + def test_generate_from_inputs_embeds(self, _, num_beams): + """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" + self.skipTest("BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs") + + @pytest.mark.generate + def test_inputs_embeds_matches_input_ids(self): + """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" + self.skipTest("BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs") + + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + """Override rope scaling from config test to handle BLT's sub-config structure.""" + if self.rotary_embedding_layer is None: + self.skipTest("Rotary embedding layer not set") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = self.model_tester_class.base_model_class(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + # Propagate rope_scaling to sub-configs for BLT + config.encoder_config.rope_scaling = config.rope_scaling + config.decoder_config.rope_scaling = config.rope_scaling + config.global_config.rope_scaling = config.rope_scaling + config.patcher_config.rope_scaling = config.rope_scaling + + scaled_model = self.model_tester_class.base_model_class(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): From 07816a2a79f160f0a02e86bfff284b844ca601c1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 14:30:24 +0000 Subject: [PATCH 069/139] fix sdpa == causal tests --- .../models/blt/configuration_blt.py | 21 ++---- src/transformers/models/blt/modeling_blt.py | 40 ++++++++++- tests/models/blt/test_modeling_blt.py | 67 ++++++++++++++----- 3 files changed, 91 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 84d3ee302259..3da8dcda2d6b 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -47,7 +47,6 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=2816, - _attn_implementation="sdpa", **kwargs, ): self.vocab_size = vocab_size @@ -70,8 +69,6 @@ def __init__( super().__init__(**kwargs) - self._attn_implementation = _attn_implementation - class BLTLocalDecoderConfig(PretrainedConfig): """ @@ -97,7 +94,6 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=2816, - _attn_implementation="sdpa", **kwargs, ): self.vocab_size = vocab_size @@ -116,11 +112,9 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"type": "default"} self.hidden_act = hidden_act - self._attn_implementation = _attn_implementation super().__init__(**kwargs) - self._attn_implementation = _attn_implementation class BLTGlobalTransformerConfig(PretrainedConfig): @@ -143,7 +137,6 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=5632, - _attn_implementation="sdpa", **kwargs, ): self.hidden_size = hidden_size @@ -161,7 +154,6 @@ def __init__( super().__init__(**kwargs) - self._attn_implementation = _attn_implementation class BLTPatcherConfig(PretrainedConfig): @@ -212,7 +204,6 @@ def __init__( norm_eps=1e-5, dropout=0.0, rope_theta=10000.0, - _attn_implementation="sdpa", attn_bias_type="local_block_causal", intermediate_size=2048, **kwargs, @@ -233,8 +224,6 @@ def __init__( self.rope_scaling = {"type": "default"} super().__init__(**kwargs) - self._attn_implementation = _attn_implementation - class BLTConfig(PretrainedConfig): r""" @@ -332,7 +321,6 @@ def __init__( decoder_config=None, global_config=None, tie_word_embeddings=False, - _attn_implementation="sdpa", initializer_range=0.02, rope_theta=500000.0, rope_scaling=None, @@ -342,7 +330,6 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - self._attn_implementation = _attn_implementation self.initializer_range = initializer_range self.rope_theta = rope_theta self.rope_scaling = rope_scaling or {"type": "default"} @@ -370,7 +357,7 @@ def __init__( # Initialize component configurations if patcher_config is None: - self.patcher_config = BLTPatcherConfig(_attn_implementation=_attn_implementation) + self.patcher_config = BLTPatcherConfig() logger.info("patcher_config is None, using default BLT patcher config") elif isinstance(patcher_config, dict): self.patcher_config = BLTPatcherConfig(**patcher_config) @@ -378,7 +365,7 @@ def __init__( self.patcher_config = patcher_config if encoder_config is None: - self.encoder_config = BLTLocalEncoderConfig(_attn_implementation=_attn_implementation) + self.encoder_config = BLTLocalEncoderConfig() logger.info("encoder_config is None, using default BLT encoder config") elif isinstance(encoder_config, dict): self.encoder_config = BLTLocalEncoderConfig(**encoder_config) @@ -386,7 +373,7 @@ def __init__( self.encoder_config = encoder_config if decoder_config is None: - self.decoder_config = BLTLocalDecoderConfig(_attn_implementation=_attn_implementation) + self.decoder_config = BLTLocalDecoderConfig() logger.info("decoder_config is None, using default BLT decoder config") elif isinstance(decoder_config, dict): self.decoder_config = BLTLocalDecoderConfig(**decoder_config) @@ -394,7 +381,7 @@ def __init__( self.decoder_config = decoder_config if global_config is None: - self.global_config = BLTGlobalTransformerConfig(_attn_implementation=_attn_implementation) + self.global_config = BLTGlobalTransformerConfig() logger.info("global_config is None, using default BLT global config") elif isinstance(global_config, dict): self.global_config = BLTGlobalTransformerConfig(**global_config) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index fca2ca79e8eb..6277effb3342 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -14,6 +14,7 @@ # limitations under the License. """BLT model.""" +import os from enum import Enum from typing import Callable, Optional, Union @@ -36,7 +37,8 @@ BLTLocalDecoderConfig, BLTLocalEncoderConfig, BLTPatcherConfig, -) +) +from ...masking_utils import create_causal_mask if is_torch_flex_attn_available(): @@ -251,6 +253,10 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx + # For BLT: We'll dynamically set is_causal based on context + # Decoder layers need causal behavior for generation + self.is_causal = False # Default, will be set dynamically + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -283,7 +289,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if position_embeddings is not None @@ -292,6 +297,7 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -301,6 +307,14 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + # Dynamic is_causal: Check if we're in a decoder context by checking the layer index + # BLT decoder layers should use causal attention for correct generation + original_is_causal = self.is_causal + # Enable causal behavior for decoder layers (based on context clues) + # If attention_mask is None, we're likely in a decoder that should be causal + if attention_mask is None: + self.is_causal = True + attn_output, attn_weights = attention_interface( self, query_states, @@ -311,6 +325,9 @@ def forward( scaling=self.scaling, **kwargs, ) + + # Restore original is_causal + self.is_causal = original_is_causal attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -792,6 +809,9 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.dropout = config.dropout + + # Cross-attention should not be causal + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -963,6 +983,8 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) + + class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) @@ -1057,6 +1079,19 @@ def forward( self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) + + # Create cache_position for mask construction if not provided (like LLaMA) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = attention_mask + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) @@ -1092,6 +1127,7 @@ def forward( input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, + attention_mask=causal_mask, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 6aa9eaec1a40..1535ac4741bc 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -24,12 +24,13 @@ require_torch, require_torch_accelerator, require_torch_bf16, + require_torch_sdpa, slow, torch_device, ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -from ...test_modeling_common import ids_tensor +from ...test_modeling_common import ids_tensor, _test_eager_matches_sdpa_inference, TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION if is_torch_available(): @@ -60,11 +61,11 @@ def __init__( self.batch_size = 3 # Common parameters for all configs - self.hidden_size = 32 - self.num_hidden_layers = 2 - self.num_attention_heads = 4 - self.num_key_value_heads = 4 - self.intermediate_size = 37 + self.hidden_size = 16 + self.num_hidden_layers = 1 + self.num_attention_heads = 2 + self.num_key_value_heads = 2 + self.intermediate_size = 18 self.hidden_act = "silu" self.max_position_embeddings = 512 self.vocab_size = 99 @@ -85,7 +86,6 @@ def __init__( "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, "dropout": self.dropout, - "_attn_implementation": "eager" } self.encoder_config = { @@ -100,7 +100,6 @@ def __init__( "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, "dropout": self.dropout, - "_attn_implementation": "eager" } self.decoder_config = { @@ -116,9 +115,7 @@ def __init__( "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, - "dropout": self.dropout, - "_attn_implementation": "eager" - } + "dropout": self.dropout, } self.global_config = { "hidden_size": self.hidden_size * 2, # Double the hidden size for global transformer @@ -131,9 +128,7 @@ def __init__( "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, - "dropout": self.dropout, - "_attn_implementation": "eager" - } + "dropout": self.dropout, } def get_config(self): config = BLTConfig( @@ -153,9 +148,7 @@ def get_config(self): encoder_config=self.encoder_config, decoder_config=self.decoder_config, global_config=self.global_config, - tie_word_embeddings=False, - _attn_implementation="eager" - ) + tie_word_embeddings=False, ) config.num_attention_heads = config.decoder_config.num_attention_heads config.num_hidden_layers = config.decoder_config.num_hidden_layers @@ -206,6 +199,45 @@ def test_inputs_embeds_matches_input_ids(self): """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" self.skipTest("BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs") + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + def test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + ): + "We need to relax a bit the `atols` for fp32 here due to the altup projections" + atols = { + ("cpu", False, torch.float32): 2e-2, # this was relaxed + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 2e-2, # this was relaxed + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 2e-2, # this was relaxed + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 2e-2, # this was relaxed + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols + ) + + + def test_torchscript_simple(self): + """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" + self.skipTest("BLT has complex patching logic that's not compatible with torchscript") + + def test_torchscript_output_hidden_state(self): + """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" + self.skipTest("BLT has complex patching logic that's not compatible with torchscript") + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): @@ -274,7 +306,6 @@ def test_blt(self): inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - old_input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(torch_device) generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) output_text = tokenizer.decode(generated_ids[0]) From db4cc78590153d11db0f20de67351b614be1eb68 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 15:26:25 +0000 Subject: [PATCH 070/139] fix small model test and some gradient checkpointing --- src/transformers/models/blt/modeling_blt.py | 62 ++++++++++++++++----- tests/models/blt/test_modeling_blt.py | 23 +++++--- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 6277effb3342..e0e7ca8492b1 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -23,6 +23,7 @@ import torch.nn import torch.nn as nn from torch.nn import functional as F +from ...modeling_layers import GradientCheckpointingLayer from ...activations import ACT2FN from ...cache_utils import Cache @@ -160,7 +161,7 @@ def forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) -class BLTTransformerLayer(nn.Module): +class BLTTransformerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -253,9 +254,7 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - # For BLT: We'll dynamically set is_causal based on context - # Decoder layers need causal behavior for generation - self.is_causal = False # Default, will be set dynamically + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -289,6 +288,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if position_embeddings is not None @@ -297,7 +297,6 @@ def forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -307,10 +306,9 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Dynamic is_causal: Check if we're in a decoder context by checking the layer index + # Check if we're in a decoder context by checking the layer index # BLT decoder layers should use causal attention for correct generation original_is_causal = self.is_causal - # Enable causal behavior for decoder layers (based on context clues) # If attention_mask is None, we're likely in a decoder that should be causal if attention_mask is None: self.is_causal = True @@ -326,7 +324,6 @@ def forward( **kwargs, ) - # Restore original is_causal self.is_causal = original_is_causal attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -572,6 +569,7 @@ def __init__(self, config: BLTLocalEncoderConfig): super().__init__() self.config = config + self.gradient_checkpointing = False self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -635,7 +633,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] if output_attentions: @@ -700,6 +710,7 @@ def __init__(self, config: BLTLocalDecoderConfig): # Extract config values to instance attributes self.config = config self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove + self.gradient_checkpointing = False self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -779,7 +790,19 @@ def forward( ) hidden_states = hidden_states + cross_attention_output - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] if output_attentions: @@ -810,7 +833,6 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.scaling = self.head_dim**-0.5 self.dropout = config.dropout - # Cross-attention should not be causal self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -906,6 +928,7 @@ def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() self.config = config + self.gradient_checkpointing = False self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): @@ -944,7 +967,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] if output_attentions: @@ -983,8 +1018,6 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) - - class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) @@ -1080,7 +1113,6 @@ def forward( self.config.encoder_hash_byte_group_vocab, ) - # Create cache_position for mask construction if not provided (like LLaMA) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 1535ac4741bc..01f89f59f46a 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -65,15 +65,18 @@ def __init__( self.num_hidden_layers = 1 self.num_attention_heads = 2 self.num_key_value_heads = 2 - self.intermediate_size = 18 + self.intermediate_size = 32 self.hidden_act = "silu" - self.max_position_embeddings = 512 - self.vocab_size = 99 + self.max_position_embeddings = 32 + self.vocab_size = 32 self.rope_theta = 500000.0 self.rope_scaling = {"rope_type": "default"} self.norm_eps = 1e-5 self.dropout = 0.0 - + self.encoder_hash_byte_group_size = [2, 3] + self.encoder_hash_byte_group_vocab = 64 + self.encoder_hash_byte_group_nb_functions = 1 + # Common parameters for all configs self.patcher_config = { "hidden_size": self.hidden_size, "num_hidden_layers": self.num_hidden_layers, @@ -115,7 +118,8 @@ def __init__( "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, - "dropout": self.dropout, } + "dropout": self.dropout, + } self.global_config = { "hidden_size": self.hidden_size * 2, # Double the hidden size for global transformer @@ -128,7 +132,8 @@ def __init__( "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, "norm_eps": self.norm_eps, - "dropout": self.dropout, } + "dropout": self.dropout, + } def get_config(self): config = BLTConfig( @@ -141,9 +146,9 @@ def get_config(self): patching_batch_size=1, max_patch_length=None, cross_attn_k=2, - encoder_hash_byte_group_size=[3, 4, 5, 6, 7, 8], - encoder_hash_byte_group_vocab=500002, - encoder_hash_byte_group_nb_functions=1, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, patcher_config=self.patcher_config, encoder_config=self.encoder_config, decoder_config=self.decoder_config, From c20e4849311d2bb7d307b9feca4a44112cd8837e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 15:29:59 +0000 Subject: [PATCH 071/139] skip training GC tests --- tests/models/blt/test_modeling_blt.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 01f89f59f46a..1b81305a3904 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -284,6 +284,23 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @unittest.skip(reason="Training is not supported yet") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): From d2dab12910b3ba0d288ecc41affd747add987ca2 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 15:40:37 +0000 Subject: [PATCH 072/139] fix test --- src/transformers/models/blt/modeling_blt.py | 1 + tests/models/blt/test_modeling_blt.py | 82 +++++++++++---------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index e0e7ca8492b1..7128b8d778a5 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1370,6 +1370,7 @@ def patch_lengths_from_entropies( class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 1b81305a3904..ec60fb7eb056 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -30,7 +30,11 @@ ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -from ...test_modeling_common import ids_tensor, _test_eager_matches_sdpa_inference, TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION +from ...test_modeling_common import ( + ids_tensor, + _test_eager_matches_sdpa_inference, + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, +) if is_torch_available(): @@ -153,7 +157,8 @@ def get_config(self): encoder_config=self.encoder_config, decoder_config=self.decoder_config, global_config=self.global_config, - tie_word_embeddings=False, ) + tie_word_embeddings=False, + ) config.num_attention_heads = config.decoder_config.num_attention_heads config.num_hidden_layers = config.decoder_config.num_hidden_layers @@ -197,12 +202,16 @@ class BLTModelTest(CausalLMModelTest, unittest.TestCase): @parameterized.expand([("greedy", 1), ("beam search", 2)]) def test_generate_from_inputs_embeds(self, _, num_beams): """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" - self.skipTest("BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs") + self.skipTest( + "BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + ) @pytest.mark.generate def test_inputs_embeds_matches_input_ids(self): """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" - self.skipTest("BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs") + self.skipTest( + "BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + ) @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @@ -234,7 +243,6 @@ def test_eager_matches_sdpa_inference( self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols ) - def test_torchscript_simple(self): """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" self.skipTest("BLT has complex patching logic that's not compatible with torchscript") @@ -243,7 +251,6 @@ def test_torchscript_output_hidden_state(self): """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" self.skipTest("BLT has complex patching logic that's not compatible with torchscript") - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): """Override rope scaling from config test to handle BLT's sub-config structure.""" @@ -267,7 +274,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): config.decoder_config.rope_scaling = config.rope_scaling config.global_config.rope_scaling = config.rope_scaling config.patcher_config.rope_scaling = config.rope_scaling - + scaled_model = self.model_tester_class.base_model_class(config) scaled_model.to(torch_device) scaled_model.eval() @@ -301,7 +308,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): def tearDown(self): @@ -444,52 +450,52 @@ def test_model_logits_bf16(self): [ [ -10.5000, - -10.7500, - -6.2188, - -10.5625, - -10.3750, - -9.1875, + -10.6875, + -6.1875, + -10.5000, + -10.3125, + -9.1250, -8.5000, -8.6250, -9.1875, - -9.6250, + -9.5625, -9.3750, -8.5000, -9.0625, - -3.4219, + -3.4062, 2.9688, -10.3125, -6.4062, - -6.0000, + -5.9688, -9.6875, - -9.2500, + -9.1875, -8.8125, - -9.8750, - -9.7500, - -9.5000, -9.8125, - -9.5000, - -9.0625, - -9.8750, - -9.5000, - -9.3750, + -9.7500, + -9.4375, + -9.7500, + -9.4375, + -9.0000, + -9.7500, + -9.4375, + -9.3125, ], [ - -13.3750, + -13.3125, + -13.1875, + -5.6875, -13.2500, - -5.5938, - -13.3750, -13.5000, -8.7500, -7.0312, -7.0000, - -10.1875, + -10.1250, -10.3750, -9.8750, - -7.8125, + -7.7812, -8.8750, - -5.3125, - -3.5469, + -5.2500, + -3.5312, -12.5625, -9.1875, -6.7812, @@ -497,13 +503,13 @@ def test_model_logits_bf16(self): -9.2500, -10.6250, -11.5000, - -11.2500, - -11.0000, - -10.6250, + -11.1875, -10.9375, - -11.1250, - -11.3750, -10.5625, + -10.8750, + -11.0625, + -11.3750, + -10.5000, -10.0000, ], ], @@ -517,6 +523,8 @@ def test_model_logits_bf16(self): with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] + # print(output[0, :2, :30]) + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) @slow From 2dce4bf556e72b26608f572d87065ae4ed63247c Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 16:43:56 +0000 Subject: [PATCH 073/139] updated modular --- src/transformers/models/blt/modular_blt.py | 825 +++++++++++++++++---- 1 file changed, 682 insertions(+), 143 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 6219b392b791..2b675efeacff 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,17 +14,24 @@ # limitations under the License. """BLT modular model, inheriting from Mllama where appropriate.""" +import os +from enum import Enum from typing import Callable, Optional, Union import torch +import torch.distributions import torch.nn as nn import torch.nn.functional as F +from ...activations import ACT2FN from ...cache_utils import Cache from ...generation.utils import GenerationMixin -from ...modeling_outputs import CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import is_torch_flex_attn_available, logging +from ...modeling_layers import GradientCheckpointingLayer +from ...masking_utils import create_causal_mask # Import configuration classes from .configuration_blt import ( @@ -44,26 +51,269 @@ from ..mllama.modeling_mllama import ( MllamaPreTrainedModel, MllamaRotaryEmbedding, - MllamaSelfAttentionDecoderLayer, MllamaTextMLP, MllamaTextRMSNorm, eager_attention_forward, ) -# Import other utility functions and classes from original BLT -from .modeling_blt import ( - PatchingModeEnum, - _prepare_patch_cross_attention_mask, - apply_rotary_pos_emb, - compute_hash_embeddings, - init_hash_embeddings, - process_patch_lengths, -) - logger = logging.get_logger(__name__) +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] + sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime**powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function( + token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = ( + torch.arange(num_patches, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, num_patches, seq_len) + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = ( + torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError( + f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" + ) + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + # ============================================================================== # INHERITED COMPONENTS (minimal changes from Mllama) # ============================================================================== @@ -77,8 +327,33 @@ class BLTRMSNorm(MllamaTextRMSNorm): pass -class BLTRotaryEmbedding(MllamaRotaryEmbedding): - pass +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling.get("type", "default") + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # ============================================================================== @@ -99,55 +374,105 @@ class BLTPreTrainedModel(MllamaPreTrainedModel): _supports_cache_class = False def _init_weights(self, module): - if isinstance(module, nn.Linear): - std = getattr(module, "_custom_std", module.in_features ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) - if module.bias is not None: - nn.init.zeros_(module.bias) + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Embedding): - std = getattr(module, "_custom_std", module.embedding_dim ** (-0.5)) - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-3 * std, - b=3 * std, - ) + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, BLTRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.RMSNorm): + module.weight.data.fill_(1.0) + + +class BLTTransformerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states - elif isinstance(module, BLTModel): - if module.encoder_hash_tok_embedding is not None: - emb_std = module.config.encoder_config.hidden_size ** (-0.5) - for emb in module.encoder_hash_tok_embedding: - emb._custom_std = emb_std + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - elif isinstance(module, BLTLocalEncoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + outputs = (hidden_states,) - elif isinstance(module, BLTLocalDecoder): - if module.patch_embedding_projection is not None: - module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + if output_attentions: + outputs += (self_attn_weights,) - elif isinstance(module, BLTPatcher): - emb_std = module.config.hidden_size ** (-0.5) - module.embed_tokens._custom_std = emb_std - module.lm_head._custom_std = emb_std + if use_cache: + outputs += (present_key_value,) - elif isinstance(module, BLTForCausalLM): - if module.lm_head is not None: - module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) + return outputs class BLTSelfAttention(nn.Module): - """BLT Self Attention that could inherit from Mllama but has some BLT-specific patterns.""" - def __init__(self, config, layer_idx: int): super().__init__() self.config = config @@ -161,6 +486,8 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx + self.is_causal = False + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -169,12 +496,13 @@ def __init__(self, config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - past_key_value=None, - cache_position=None, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -187,17 +515,20 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if position_embeddings is not None + else {"cache_position": cache_position} + ) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - output_attentions = False if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -207,6 +538,13 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + # Check if we're in a decoder context by checking the layer index + # BLT decoder layers should use causal attention for correct generation + original_is_causal = self.is_causal + # If attention_mask is None, we're likely in a decoder that should be causal + if attention_mask is None: + self.is_causal = True + attn_output, attn_weights = attention_interface( self, query_states, @@ -217,6 +555,8 @@ def forward( scaling=self.scaling, **kwargs, ) + + self.is_causal = original_is_causal attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -227,10 +567,6 @@ def forward( return attn_output, attn_weights, past_key_value -class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): - pass - - # ============================================================================== # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) # ============================================================================== @@ -241,6 +577,7 @@ def __init__(self, config: BLTLocalEncoderConfig): super().__init__() self.config = config + self.gradient_checkpointing = False self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -265,17 +602,29 @@ def __init__(self, config: BLTLocalEncoderConfig): def forward( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, input_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): """ """ + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + if input_embeds is None: input_embeds = self.embed_tokens(input_ids) @@ -283,15 +632,33 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for idx, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) @@ -312,7 +679,7 @@ def forward( patch_embeds = patch_embeds + cross_attention_output encoder_cross_states = patch_embeds - return hidden_states, encoder_cross_states + return hidden_states, encoder_cross_states, all_hidden_states, all_attentions def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ @@ -351,6 +718,7 @@ def __init__(self, config: BLTLocalDecoderConfig): # Extract config values to instance attributes self.config = config self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove + self.gradient_checkpointing = False self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -373,22 +741,28 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) - # self.lm_head = nn.Linear( - # config.hidden_size, - # config.vocab_size, - # bias=False, - # ) - def forward( self, - tokens: torch.Tensor, - embeds: Optional[torch.Tensor], + input_ids: Optional[torch.LongTensor] = None, + embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + batch_size, _, _ = embeds.shape hidden_states = embeds @@ -401,11 +775,16 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + # Use sequence length from embeds (standard transformers pattern) + seq_len = embeds.shape[1] + position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states cross_attention_output, _, _ = self.cross_attn_layers[i]( @@ -419,12 +798,31 @@ def forward( ) hidden_states = hidden_states + cross_attention_output - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add final hidden state after all layers + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + logits = self.norm(hidden_states) # logits = self.lm_head(logits) - return logits, cache + return logits, all_hidden_states, all_attentions class BLTCrossAttention(nn.Module): @@ -442,6 +840,8 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.dropout = config.dropout + + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -461,6 +861,7 @@ def forward( output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -535,6 +936,7 @@ def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() self.config = config + self.gradient_checkpointing = False self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): @@ -545,9 +947,21 @@ def __init__(self, config: BLTGlobalTransformerConfig): def forward( self, input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, + **kwargs, ): + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + batch_size, seq_len, _ = input_embeds.shape hidden_states = input_embeds @@ -558,10 +972,28 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + None, # attention_mask + None, # past_key_value + False, # output_attentions + False, # use_cache + None, # cache_position + ) + else: + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] - return hidden_states, cache + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + return hidden_states, all_hidden_states, all_attentions class BLTModel(BLTPreTrainedModel): @@ -584,71 +1016,114 @@ def __init__(self, config: BLTConfig): else: self.patcher = None + # Initialize weights and apply final processing + self.post_init() + def forward( self, - tokens: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, patch_lengths: Optional[torch.Tensor] = None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ): + ) -> Union[BaseModelOutputWithPast, tuple]: """ Args: - tokens (torch.Tensor): Input token ids. + input_ids (torch.LongTensor, optional): Input token ids. patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. Returns: - torch.Tensor: Final hidden states (as before). + Union[BaseModelOutputWithPast, tuple]: Model outputs. """ - batch_size, sequence_length = tokens.shape + # Set defaults from config when parameters are None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Initialize collections to None - we will ONLY collect from decoder + all_hidden_states = None + all_attentions = None + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape + else: + batch_size, sequence_length, _ = inputs_embeds.shape # Handle patching if patch_lengths is None: - if self.config.patching_mode == PatchingModeEnum.entropy: + if self.config.patching_mode == PatchingModeEnum.entropy and self.patcher is not None: + if input_ids is None: + raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( - tokens, + input_ids, patch_size=self.config.patch_size, threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, patching_batch_size=self.config.patching_batch_size, - device=tokens.device, + device=input_ids.device, ) else: + device = input_ids.device if input_ids is not None else inputs_embeds.device + dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype patch_lengths = process_patch_lengths( - torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device), self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - encoder_embeds = compute_hash_embeddings( - tokens, - self.local_encoder, - self.encoder_hash_tok_embedding, - self.config.encoder_hash_byte_group_nb_functions, - self.config.encoder_hash_byte_group_size, - self.config.encoder_hash_byte_group_vocab, - ) + if inputs_embeds is not None: + encoder_embeds = inputs_embeds + else: + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = attention_mask + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states = self.local_encoder( - input_ids=tokens, + encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = self.local_encoder( + input_ids=input_ids, input_embeds=encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, + output_attentions=False, # Don't collect encoder attentions + output_hidden_states=False, # Don't collect encoder hidden states ) + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, _ = self.global_transformer( + global_hidden_states, global_hidden_states_all, global_attentions_all = self.global_transformer( input_embeds=global_hidden_states, + output_attentions=False, # Don't collect global transformer attentions + output_hidden_states=False, # Don't collect global transformer hidden states ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, @@ -658,20 +1133,53 @@ def forward( self.config.cross_attn_k, encoder_embeds.dtype, ) - output, _ = self.local_decoder( - tokens=tokens, + output, decoder_hidden_states_all, decoder_attentions_all = self.local_decoder( + input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, + attention_mask=causal_mask, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + output_attentions=output_attentions, # Only collect decoder attentions + output_hidden_states=output_hidden_states, # Only collect decoder hidden states ) - if output_hidden_states or output_attentions: - if return_dict: - return {"last_hidden_state": output, "hidden_states": None, "attentions": None} - else: - return (output, None, None) - return output + + # Only use decoder outputs (which match the expected num_hidden_layers) + if output_hidden_states and decoder_hidden_states_all is not None: + all_hidden_states = decoder_hidden_states_all + else: + all_hidden_states = None + + if output_attentions and decoder_attentions_all is not None: + all_attentions = decoder_attentions_all + else: + all_attentions = None + + if not return_dict: + output = (output,) + if past_key_values is not None: + output = output + (past_key_values,) + if all_hidden_states is not None: + output = output + (all_hidden_states,) + if all_attentions is not None: + output = output + (all_attentions,) + return output + + return BaseModelOutputWithPast( + last_hidden_state=output, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + def get_input_embeddings(self): + """Returns the model's input embeddings.""" + return self.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + """Sets the model's input embeddings.""" + self.local_encoder.embed_tokens = value def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: """Convert patch lengths to patch IDs for each token position.""" @@ -840,6 +1348,7 @@ def patch_lengths_from_entropies( class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] config_class = BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -872,29 +1381,34 @@ def get_decoder(self): def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - cache_position=None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, - ): + ) -> Union[CausalLMOutputWithPast, tuple]: """ Args: input_ids (torch.LongTensor): Input token ids. attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. labels (torch.LongTensor, optional): Labels for language modeling loss. Returns: - CausalLMOutputWithPast or tuple: Standard transformers output. + Union[CausalLMOutputWithPast, tuple]: Standard transformers output. """ + # Set defaults from config when parameters are None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.return_dict + # Route only input_ids to BLTModel (as tokens) - hidden_states = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -907,12 +1421,37 @@ def forward( cache_position=cache_position, **kwargs, ) - if isinstance(hidden_states, dict): - sequence_output = hidden_states["last_hidden_state"] - elif isinstance(hidden_states, tuple): - sequence_output = hidden_states[0] + + if isinstance(outputs, dict): + sequence_output = outputs["last_hidden_state"] + past_key_values = outputs.get("past_key_values") + hidden_states = outputs.get("hidden_states") + attentions = outputs.get("attentions") + elif isinstance(outputs, tuple): + sequence_output = outputs[0] + # Handle tuple format: (output, past_key_values?, hidden_states?, attentions?) + idx = 1 + past_key_values = None + hidden_states = None + attentions = None + + if len(outputs) > idx and use_cache: + past_key_values = outputs[idx] + idx += 1 + + if len(outputs) > idx and output_hidden_states: + hidden_states = outputs[idx] + idx += 1 + + if len(outputs) > idx and output_attentions: + attentions = outputs[idx] + idx += 1 else: - sequence_output = hidden_states + sequence_output = outputs + past_key_values = None + hidden_states = None + attentions = None + logits = self.lm_head(sequence_output) loss = None if labels is not None: @@ -929,9 +1468,9 @@ def forward( return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, + past_key_values=past_key_values, + hidden_states=hidden_states, + attentions=attentions, ) From e64ed907b9790a44207c7a37151f85e1be016418 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 16:44:40 +0000 Subject: [PATCH 074/139] update modular --- src/transformers/models/blt/modular_blt.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 2b675efeacff..4e2901822fff 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -314,9 +314,7 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded -# ============================================================================== # INHERITED COMPONENTS (minimal changes from Mllama) -# ============================================================================== class BLTMLP(MllamaTextMLP): @@ -356,9 +354,9 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# ============================================================================== + # INHERITED BUT CUSTOMIZED COMPONENTS -# ============================================================================== + class BLTPreTrainedModel(MllamaPreTrainedModel): @@ -567,9 +565,7 @@ def forward( return attn_output, attn_weights, past_key_value -# ============================================================================== # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) -# ============================================================================== class BLTLocalEncoder(nn.Module): From 7738005359ca7585deed29c4658349714812c3ff Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 16:54:54 +0000 Subject: [PATCH 075/139] ruff --- .../models/blt/configuration_blt.py | 3 - src/transformers/models/blt/modeling_blt.py | 93 +++++++++++------- src/transformers/models/blt/modular_blt.py | 97 +++++++++++-------- tests/models/blt/test_modeling_blt.py | 5 +- 4 files changed, 117 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 3da8dcda2d6b..8e73b88798b9 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -14,7 +14,6 @@ # limitations under the License. """BLT model configuration""" - from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -116,7 +115,6 @@ def __init__( super().__init__(**kwargs) - class BLTGlobalTransformerConfig(PretrainedConfig): """ Configuration class for the BLT Global Transformer component. @@ -155,7 +153,6 @@ def __init__( super().__init__(**kwargs) - class BLTPatcherConfig(PretrainedConfig): r""" Configuration class for the BLT Patcher/Entropy model component. diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 7128b8d778a5..f014315f3c34 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -14,7 +14,6 @@ # limitations under the License. """BLT model.""" -import os from enum import Enum from typing import Callable, Optional, Union @@ -23,11 +22,11 @@ import torch.nn import torch.nn as nn from torch.nn import functional as F -from ...modeling_layers import GradientCheckpointingLayer from ...activations import ACT2FN from ...cache_utils import Cache from ...generation.utils import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -38,8 +37,7 @@ BLTLocalDecoderConfig, BLTLocalEncoderConfig, BLTPatcherConfig, -) -from ...masking_utils import create_causal_mask +) if is_torch_flex_attn_available(): @@ -254,7 +252,7 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.is_causal = False + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -312,7 +310,7 @@ def forward( # If attention_mask is None, we're likely in a decoder that should be causal if attention_mask is None: self.is_causal = True - + attn_output, attn_weights = attention_interface( self, query_states, @@ -323,7 +321,7 @@ def forward( scaling=self.scaling, **kwargs, ) - + self.is_causal = original_is_causal attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -624,7 +622,9 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - position_ids = torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -640,12 +640,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -797,12 +802,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -832,7 +842,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.dropout = config.dropout - + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -974,12 +984,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -1066,7 +1081,9 @@ def forward( """ # Set defaults from config when parameters are None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.return_dict # Initialize collections to None - we will ONLY collect from decoder @@ -1075,7 +1092,7 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - + if input_ids is not None: batch_size, sequence_length = input_ids.shape else: @@ -1127,16 +1144,18 @@ def forward( cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = self.local_encoder( - input_ids=input_ids, - input_embeds=encoder_embeds, - patch_embeds=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - output_attentions=False, # Don't collect encoder attentions - output_hidden_states=False, # Don't collect encoder hidden states + encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = ( + self.local_encoder( + input_ids=input_ids, + input_embeds=encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + output_attentions=False, # Don't collect encoder attentions + output_hidden_states=False, # Don't collect encoder hidden states + ) ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) @@ -1172,7 +1191,7 @@ def forward( all_hidden_states = decoder_hidden_states_all else: all_hidden_states = None - + if output_attentions and decoder_attentions_all is not None: all_attentions = decoder_attentions_all else: @@ -1426,7 +1445,9 @@ def forward( """ # Set defaults from config when parameters are None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.return_dict # Route only input_ids to BLTModel (as tokens) @@ -1443,7 +1464,7 @@ def forward( cache_position=cache_position, **kwargs, ) - + if isinstance(outputs, dict): sequence_output = outputs["last_hidden_state"] past_key_values = outputs.get("past_key_values") @@ -1456,15 +1477,15 @@ def forward( past_key_values = None hidden_states = None attentions = None - + if len(outputs) > idx and use_cache: past_key_values = outputs[idx] idx += 1 - + if len(outputs) > idx and output_hidden_states: hidden_states = outputs[idx] idx += 1 - + if len(outputs) > idx and output_attentions: attentions = outputs[idx] idx += 1 diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 4e2901822fff..36390e0915ef 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,6 @@ # limitations under the License. """BLT modular model, inheriting from Mllama where appropriate.""" -import os from enum import Enum from typing import Callable, Optional, Union @@ -23,15 +22,13 @@ import torch.nn as nn import torch.nn.functional as F -from ...activations import ACT2FN from ...cache_utils import Cache from ...generation.utils import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import is_torch_flex_attn_available, logging -from ...modeling_layers import GradientCheckpointingLayer -from ...masking_utils import create_causal_mask # Import configuration classes from .configuration_blt import ( @@ -50,7 +47,6 @@ # Import from mllama for inheritance from ..mllama.modeling_mllama import ( MllamaPreTrainedModel, - MllamaRotaryEmbedding, MllamaTextMLP, MllamaTextRMSNorm, eager_attention_forward, @@ -354,11 +350,9 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - # INHERITED BUT CUSTOMIZED COMPONENTS - class BLTPreTrainedModel(MllamaPreTrainedModel): """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" @@ -484,7 +478,7 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.is_causal = False + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -542,7 +536,7 @@ def forward( # If attention_mask is None, we're likely in a decoder that should be causal if attention_mask is None: self.is_causal = True - + attn_output, attn_weights = attention_interface( self, query_states, @@ -553,7 +547,7 @@ def forward( scaling=self.scaling, **kwargs, ) - + self.is_causal = original_is_causal attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -628,7 +622,9 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - position_ids = torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -644,12 +640,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -801,12 +802,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -836,7 +842,7 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 self.dropout = config.dropout - + self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -978,12 +984,17 @@ def forward( position_embeddings, None, # attention_mask None, # past_key_value - False, # output_attentions - False, # use_cache + False, # output_attentions + False, # use_cache None, # cache_position ) else: - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None, output_attentions=output_attentions) + layer_outputs = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=None, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -1040,7 +1051,9 @@ def forward( """ # Set defaults from config when parameters are None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.return_dict # Initialize collections to None - we will ONLY collect from decoder @@ -1049,7 +1062,7 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - + if input_ids is not None: batch_size, sequence_length = input_ids.shape else: @@ -1101,16 +1114,18 @@ def forward( cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = self.local_encoder( - input_ids=input_ids, - input_embeds=encoder_embeds, - patch_embeds=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - output_attentions=False, # Don't collect encoder attentions - output_hidden_states=False, # Don't collect encoder hidden states + encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = ( + self.local_encoder( + input_ids=input_ids, + input_embeds=encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + output_attentions=False, # Don't collect encoder attentions + output_hidden_states=False, # Don't collect encoder hidden states + ) ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) @@ -1146,7 +1161,7 @@ def forward( all_hidden_states = decoder_hidden_states_all else: all_hidden_states = None - + if output_attentions and decoder_attentions_all is not None: all_attentions = decoder_attentions_all else: @@ -1400,7 +1415,9 @@ def forward( """ # Set defaults from config when parameters are None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.return_dict # Route only input_ids to BLTModel (as tokens) @@ -1417,7 +1434,7 @@ def forward( cache_position=cache_position, **kwargs, ) - + if isinstance(outputs, dict): sequence_output = outputs["last_hidden_state"] past_key_values = outputs.get("past_key_values") @@ -1430,15 +1447,15 @@ def forward( past_key_values = None hidden_states = None attentions = None - + if len(outputs) > idx and use_cache: past_key_values = outputs[idx] idx += 1 - + if len(outputs) > idx and output_hidden_states: hidden_states = outputs[idx] idx += 1 - + if len(outputs) > idx and output_attentions: attentions = outputs[idx] idx += 1 diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index ec60fb7eb056..6079096722f3 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -17,6 +17,7 @@ import pytest from parameterized import parameterized + from transformers import AutoTokenizer, is_torch_available, set_seed from transformers.testing_utils import ( cleanup, @@ -31,9 +32,9 @@ from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...test_modeling_common import ( - ids_tensor, - _test_eager_matches_sdpa_inference, TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + _test_eager_matches_sdpa_inference, + ids_tensor, ) From 8b2a23830eeb279422bf80e1b41a8b0a21d34ccb Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 15 Jul 2025 17:05:37 +0000 Subject: [PATCH 076/139] adding modular + modeling --- src/transformers/models/blt/modeling_blt.py | 836 ++++++++++++-------- tests/models/blt/test_modeling_blt.py | 4 + 2 files changed, 500 insertions(+), 340 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index f014315f3c34..4a6c9bba1888 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/blt/modular_blt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_blt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. # @@ -12,25 +18,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BLT model.""" from enum import Enum from typing import Callable, Optional, Union import torch import torch.distributions -import torch.nn import torch.nn as nn -from torch.nn import functional as F +import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache from ...generation.utils import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import is_torch_flex_attn_available, logging +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging + +# Import configuration classes from .configuration_blt import ( BLTConfig, BLTGlobalTransformerConfig, @@ -44,6 +52,10 @@ from torch.nn.attention.flex_attention import BlockMask +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -56,88 +68,19 @@ class PatchingModeEnum(str, Enum): byte = "byte" -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] - sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] - - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - - return q_rot.type_as(q), k_rot.type_as(k) - - class BLTMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # Ignore copy self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -158,6 +101,199 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling.get("type", "default") + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class BLTPreTrainedModel(PreTrainedModel): + """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" + + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _supports_cache_class = False + _supports_static_cache = False # static cache cannot have different shapes for each layer + _supports_sdpa = True + _supports_flash_attn = True + _supports_quantized_cache = True + _supports_flex_attn = True + _supports_attention_backend = True + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + + def _init_weights(self, module): + std = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, BLTRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.RMSNorm): + module.weight.data.fill_(1.0) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class BLTTransformerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx: int): @@ -238,6 +374,77 @@ def forward( return outputs +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] + sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... + q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + class BLTSelfAttention(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() @@ -316,250 +523,24 @@ def forward( query_states, key_states, value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - self.is_causal = original_is_causal - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): - primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) - powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime**powers - return torch.sum(token_tensor * prime_powers, dim=-1) - - -def byte_group_hash_function( - token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 -): - """Hash token groups and map to range [0, max_hash].""" - with torch.no_grad(): - batch_size, seq_len = token_ids.shape - # Add padding for sliding window - padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) - padded_tokens = torch.cat([padding, token_ids], dim=1) - - # Create sliding windows and compute hashes - windows = padded_tokens.unfold(1, group_size, 1) - hashes = rolling_polynomial_hash(windows, hash_func_nb) - hash_values = hashes % max_hash - - hash_values.requires_grad = False - return hash_values - - -def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): - """Initialize hash-based token embeddings for the BLT encoder.""" - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] - return nn.ModuleList(embeddings) - - -def compute_hash_embeddings( - local_encoder_tokens: torch.Tensor, - local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, - encoder_hash_byte_group_nb_functions: int, - encoder_hash_byte_group_size: list, - encoder_hash_byte_group_vocab: int, -) -> torch.Tensor: - """Compute token embeddings enhanced with hash-based embeddings.""" - embeddings = local_encoder.embed_tokens(local_encoder_tokens) - embedding_idx = 0 - for func_nb in range(encoder_hash_byte_group_nb_functions): - for group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab - ) - embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) - embedding_idx += 1 - - return embeddings - - -def _prepare_patch_cross_attention_mask( - patch_ids: torch.Tensor, - num_patches: int, - sequence_length: int, - patches_as_queries: bool = False, - cross_attn_k: int = 1, - dtype: torch.dtype = torch.float32, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Prepare cross-attention mask for patch-based attention, following mllama's robust approach. - - This function creates masks that control which patches can attend to which other patches, - with support for query/key role swapping and cross-attention multipliers. - - Args: - patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. - num_patches (int): Total number of patches. - sequence_length (int): Length of the sequence. - patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. - cross_attn_k (int): Cross-attention multiplier for repeating patches. - dtype (torch.dtype): Data type for the output mask. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows - """ - batch_size, seq_len = patch_ids.shape - device = patch_ids.device - - # Determine query and key lengths based on configuration - if patches_as_queries: - q_len = num_patches * cross_attn_k - kv_len = sequence_length - # Create patch-to-sequence mapping - q_patch_ids = ( - torch.arange(num_patches, device=device) - .unsqueeze(0) - .unsqueeze(-1) - .expand(batch_size, num_patches, seq_len) - ) - kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) - else: - q_len = sequence_length - kv_len = num_patches * cross_attn_k - # Create sequence-to-patch mapping - q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) - kv_patch_ids = ( - torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) - ) - - # Create base attention mask - boolean mask where True means "should attend" - # Exact patch matching - cross_attention_mask = q_patch_ids == kv_patch_ids - - # Handle cross_attn_k multiplier by repeating along appropriate dimension - repeat_dim = 1 if patches_as_queries else -1 - cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) - - # Validate dimensions - expected_shape = (batch_size, q_len, kv_len) - if cross_attention_mask.shape != expected_shape: - raise ValueError( - f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" - ) - - # Reshape so it can be used by attn module - add head dimension - cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] - - # Invert the mask (following mllama pattern exactly) - # True -> 0.0 (attend), False -> 1.0 (will become -inf) - inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # Apply full-row bias (following mllama pattern exactly) - # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: - """ - Splits patch lengths into smaller segments if they exceed `max_patch_length`. - Pads the result to uniform length across the batch. - - Args: - patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. - max_patch_length (int, optional): Maximum allowed length per patch. - - Returns: - torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. - """ - if max_patch_length is None: - return patch_lengths - - batch_size = patch_lengths.size(0) - processed = [] - - for seq in patch_lengths: - splits = [] - for length in seq[seq > 0]: - length = length.item() - full_chunks, remainder = divmod(length, max_patch_length) - splits.extend([max_patch_length] * full_chunks) - if remainder: - splits.append(remainder) - processed.append(splits) - - # Find max length to pad to - max_len = max(len(splits) for splits in processed) - padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - - for i, splits in enumerate(processed): - if splits: - padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - - # Trim zero columns - if (padded != 0).any(dim=0).sum() < padded.shape[1]: - last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 - padded = padded[:, :last_nonzero] - - return padded + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + self.is_causal = original_is_causal -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config, device=None): - super().__init__() - self.rope_type = config.rope_scaling.get("type", "default") - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + if not output_attentions: + attn_weights = None - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() + return attn_output, attn_weights, past_key_value - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) class BLTLocalEncoder(nn.Module): @@ -1003,34 +984,204 @@ def forward( return hidden_states, all_hidden_states, all_attentions -class BLTPreTrainedModel(PreTrainedModel): - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation - _supports_sdpa = True - _supports_cache_class = False +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime**powers + return torch.sum(token_tensor * prime_powers, dim=-1) - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, BLTRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.RMSNorm): - module.weight.data.fill_(1.0) +def byte_group_hash_function( + token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = ( + torch.arange(num_patches, device=device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(batch_size, num_patches, seq_len) + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = ( + torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches) + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError( + f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}" + ) + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded class BLTModel(BLTPreTrainedModel): @@ -1526,4 +1677,9 @@ def forward( "BLTGlobalTransformer", "BLTTransformerLayer", "BLTForCausalLM", + "BLTMLP", + "BLTRMSNorm", + "BLTRotaryEmbedding", + "BLTSelfAttention", + "BLTCrossAttention", ] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 6079096722f3..8e97ab4d9ec0 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -308,6 +308,10 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return + @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): From 0c23353fd4430b0291ccf0e9cdb78fc6f73b92e0 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 10:36:51 +0000 Subject: [PATCH 077/139] modular --- .../models/blt/configuration_blt.py | 12 ++-- src/transformers/models/blt/modeling_blt.py | 48 ++++++------- src/transformers/models/blt/modular_blt.py | 68 +++---------------- tests/models/blt/test_modeling_blt.py | 4 +- 4 files changed, 40 insertions(+), 92 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 8e73b88798b9..b60cdf5de150 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -63,7 +63,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"type": "default"} + self.rope_scaling = rope_scaling self.hidden_act = hidden_act super().__init__(**kwargs) @@ -109,7 +109,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"type": "default"} + self.rope_scaling = rope_scaling self.hidden_act = hidden_act super().__init__(**kwargs) @@ -147,7 +147,7 @@ def __init__( self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"type": "default"} + self.rope_scaling = rope_scaling self.hidden_act = hidden_act super().__init__(**kwargs) @@ -203,6 +203,7 @@ def __init__( rope_theta=10000.0, attn_bias_type="local_block_causal", intermediate_size=2048, + rope_scaling=None, **kwargs, ): self.vocab_size = vocab_size @@ -218,7 +219,8 @@ def __init__( self.attn_bias_type = attn_bias_type self.hidden_act = "silu" # BLT uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) - self.rope_scaling = {"type": "default"} + self.rope_scaling = rope_scaling + super().__init__(**kwargs) @@ -329,7 +331,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.rope_theta = rope_theta - self.rope_scaling = rope_scaling or {"type": "default"} + self.rope_scaling = rope_scaling # Patching configuration self.patch_in_forward = patch_in_forward diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 4a6c9bba1888..c2c64b3e27da 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -18,8 +18,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from enum import Enum from typing import Callable, Optional, Union import torch @@ -37,8 +35,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging - -# Import configuration classes from .configuration_blt import ( BLTConfig, BLTGlobalTransformerConfig, @@ -59,15 +55,6 @@ logger = logging.get_logger(__name__) -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - class BLTMLP(nn.Module): def __init__(self, config): super().__init__() @@ -106,9 +93,14 @@ def extra_repr(self): class BLTRotaryEmbedding(nn.Module): - def __init__(self, config, device=None): + def __init__(self, config: BLTConfig, device=None): super().__init__() - self.rope_type = config.rope_scaling.get("type", "default") + # BC: "rope_type" was originally "type" + self.rope_type = ( + config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + if config.rope_scaling is not None + else "default" + ) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings @@ -374,6 +366,18 @@ def forward( return outputs +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -400,18 +404,6 @@ def eager_attention_forward( return attn_output, attn_weights -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # TODO: not exactly equivalent to other transformers implementations,, need feedback # Extract first head_dim//2 elements which correspond to the unique frequencies @@ -1250,7 +1242,7 @@ def forward( batch_size, sequence_length, _ = inputs_embeds.shape # Handle patching if patch_lengths is None: - if self.config.patching_mode == PatchingModeEnum.entropy and self.patcher is not None: + if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 36390e0915ef..dcb3cf9d0837 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,6 @@ # limitations under the License. """BLT modular model, inheriting from Mllama where appropriate.""" -from enum import Enum from typing import Callable, Optional, Union import torch @@ -26,11 +25,8 @@ from ...generation.utils import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import is_torch_flex_attn_available, logging - -# Import configuration classes from .configuration_blt import ( BLTConfig, BLTGlobalTransformerConfig, @@ -44,9 +40,9 @@ from torch.nn.attention.flex_attention import BlockMask -# Import from mllama for inheritance from ..mllama.modeling_mllama import ( MllamaPreTrainedModel, + MllamaRotaryEmbedding, MllamaTextMLP, MllamaTextRMSNorm, eager_attention_forward, @@ -56,27 +52,6 @@ logger = logging.get_logger(__name__) -class PatchingModeEnum(str, Enum): - entropy = "entropy" - bpe = "bpe" - bpe_patcher = "bpe_patcher" - space = "space" - static = "static" - byte = "byte" - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # TODO: not exactly equivalent to other transformers implementations,, need feedback # Extract first head_dim//2 elements which correspond to the unique frequencies @@ -310,9 +285,6 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded -# INHERITED COMPONENTS (minimal changes from Mllama) - - class BLTMLP(MllamaTextMLP): pass @@ -321,33 +293,15 @@ class BLTRMSNorm(MllamaTextRMSNorm): pass -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config, device=None): - super().__init__() - self.rope_type = config.rope_scaling.get("type", "default") - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class BLTRotaryEmbedding(MllamaRotaryEmbedding): + def __init__(self, config: BLTConfig, device=None): + super().__init__(config=config, device=device) + # BC: "rope_type" was originally "type" + self.rope_type = ( + config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + if config.rope_scaling is not None + else "default" + ) # INHERITED BUT CUSTOMIZED COMPONENTS @@ -1069,7 +1023,7 @@ def forward( batch_size, sequence_length, _ = inputs_embeds.shape # Handle patching if patch_lengths is None: - if self.config.patching_mode == PatchingModeEnum.entropy and self.patcher is not None: + if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 8e97ab4d9ec0..aade85b7189f 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -158,6 +158,7 @@ def get_config(self): encoder_config=self.encoder_config, decoder_config=self.decoder_config, global_config=self.global_config, + rope_scaling=self.rope_scaling, tie_word_embeddings=False, ) @@ -269,7 +270,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): original_long_output = original_model(long_input).last_hidden_state set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} + config.rope_scaling = {"rope_type": scaling_type, "factor": 10.0} # Propagate rope_scaling to sub-configs for BLT config.encoder_config.rope_scaling = config.rope_scaling config.decoder_config.rope_scaling = config.rope_scaling @@ -289,7 +290,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): else: self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) @unittest.skip(reason="Training is not supported yet") From 141d788b9adc13c501a04ca5498cb5ec2d82b862 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 12:17:17 +0000 Subject: [PATCH 078/139] more modern is_casual check --- src/transformers/models/blt/modeling_blt.py | 12 ++---------- src/transformers/models/blt/modular_blt.py | 11 +---------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index c2c64b3e27da..373154d4aa56 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -18,6 +18,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Callable, Optional, Union import torch @@ -451,7 +452,7 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.is_causal = False + self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -503,13 +504,6 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Check if we're in a decoder context by checking the layer index - # BLT decoder layers should use causal attention for correct generation - original_is_causal = self.is_causal - # If attention_mask is None, we're likely in a decoder that should be causal - if attention_mask is None: - self.is_causal = True - attn_output, attn_weights = attention_interface( self, query_states, @@ -521,8 +515,6 @@ def forward( **kwargs, ) - self.is_causal = original_is_causal - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index dcb3cf9d0837..f56a21654d74 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -432,7 +432,7 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.is_causal = False + self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -484,13 +484,6 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Check if we're in a decoder context by checking the layer index - # BLT decoder layers should use causal attention for correct generation - original_is_causal = self.is_causal - # If attention_mask is None, we're likely in a decoder that should be causal - if attention_mask is None: - self.is_causal = True - attn_output, attn_weights = attention_interface( self, query_states, @@ -502,8 +495,6 @@ def forward( **kwargs, ) - self.is_causal = original_is_causal - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) From fd1dd4ab4edc186378d45f7f507d8fd00efc1221 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 15:09:00 +0000 Subject: [PATCH 079/139] cleaning up modular --- .../models/blt/configuration_blt.py | 18 +- src/transformers/models/blt/modeling_blt.py | 96 ++++----- src/transformers/models/blt/modular_blt.py | 202 ++---------------- 3 files changed, 63 insertions(+), 253 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index b60cdf5de150..d349c2bb4df1 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -39,7 +39,7 @@ def __init__( num_attention_heads=16, num_key_value_heads=None, num_hidden_layers=1, - norm_eps=1e-5, + rms_norm_eps=1e-5, dropout=0.0, max_position_embeddings=24576, rope_theta=500000.0, @@ -59,7 +59,7 @@ def __init__( self.head_dim = hidden_size // num_attention_heads self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers - self.norm_eps = norm_eps + self.rms_norm_eps = rms_norm_eps self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta @@ -86,7 +86,7 @@ def __init__( num_attention_heads=16, num_key_value_heads=None, num_hidden_layers=9, - norm_eps=1e-5, + rms_norm_eps=1e-5, dropout=0.0, max_position_embeddings=24576, rope_theta=500000.0, @@ -105,7 +105,7 @@ def __init__( self.head_dim = hidden_size // num_attention_heads self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers - self.norm_eps = norm_eps + self.rms_norm_eps = rms_norm_eps self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta @@ -128,7 +128,7 @@ def __init__( num_attention_heads=16, num_key_value_heads=None, num_hidden_layers=25, - norm_eps=1e-5, + rms_norm_eps=1e-5, dropout=0.0, max_position_embeddings=4096, rope_theta=500000.0, @@ -143,7 +143,7 @@ def __init__( self.head_dim = hidden_size // num_attention_heads self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) self.num_hidden_layers = num_hidden_layers - self.norm_eps = norm_eps + self.rms_norm_eps = rms_norm_eps self.dropout = dropout self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta @@ -172,7 +172,7 @@ class BLTPatcherConfig(PretrainedConfig): Number of key-value heads in the entropy model. max_position_embeddings (`int`, *optional*, defaults to 1024): Maximum sequence length for the entropy model. - norm_eps (`float`, *optional*, defaults to 1e-5): + rms_norm_eps (`float`, *optional*, defaults to 1e-5): Layer normalization epsilon for the entropy model. dropout (`float`, *optional*, defaults to 0.0): Dropout probability for the entropy model. @@ -198,7 +198,7 @@ def __init__( num_attention_heads=12, num_key_value_heads=None, max_position_embeddings=8192, - norm_eps=1e-5, + rms_norm_eps=1e-5, dropout=0.0, rope_theta=10000.0, attn_bias_type="local_block_causal", @@ -213,7 +213,7 @@ def __init__( self.head_dim = hidden_size // num_attention_heads self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads self.max_position_embeddings = max_position_embeddings - self.norm_eps = norm_eps + self.rms_norm_eps = rms_norm_eps self.dropout = dropout self.rope_theta = rope_theta self.attn_bias_type = attn_bias_type diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 373154d4aa56..6caf3fd6625e 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -296,8 +296,8 @@ def __init__(self, config, layer_idx: int): self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -439,7 +439,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BLTSelfAttention(nn.Module): - def __init__(self, config, layer_idx: int): + """BLT variant of MllamaTextSelfAttention. Inherits all logic directly.""" + + def __init__(self, config: BLTConfig, layer_idx: int): super().__init__() self.config = config self.num_heads = config.num_attention_heads @@ -452,23 +454,21 @@ def __init__(self, config, layer_idx: int): self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.is_causal = True def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value=None, + cache_position=None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -481,20 +481,16 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = ( - {"sin": sin, "cos": cos, "cache_position": cache_position} - if position_embeddings is not None - else {"cache_position": cache_position} - ) + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( @@ -521,7 +517,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) @@ -694,7 +690,7 @@ def __init__(self, config: BLTLocalDecoderConfig): bias=False, ) - self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 @@ -798,25 +794,22 @@ class BLTCrossAttention(nn.Module): def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): super().__init__() self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads self.layer_idx = layer_idx - # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.encoder_config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention - self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.dropout = config.dropout - - self.is_causal = False self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.q_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.is_causal = False def forward( self, @@ -824,39 +817,39 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first - query_states = self.q_proj(query_states) + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_norm(key_states) if past_key_value is not None: - # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states - # we still update the cross key states, past_cross_states, new_cross_states. And use it! + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - elif cache_position is not None and cache_position[0] != 0: + elif cache_position[0] != 0: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], past_key_value.value_cache[self.layer_idx], ) else: - if cross_attention_states is None: - raise ValueError( - "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" - ) + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) attention_interface: Callable = eager_attention_forward @@ -869,17 +862,13 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, - dropout=0.0, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, **kwargs, ) @@ -887,15 +876,10 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if full_text_row_masked_out_mask is not None: - attn_output = full_text_row_masked_out_mask[:, 0] * attn_output - - attn_output = attn_output + hidden_states - if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BLTGlobalTransformer(nn.Module): @@ -1385,7 +1369,7 @@ def __init__(self, config: BLTPatcherConfig): self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.lm_head = nn.Linear( self.config.hidden_size, diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f56a21654d74..45903552a256 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,7 @@ # limitations under the License. """BLT modular model, inheriting from Mllama where appropriate.""" -from typing import Callable, Optional, Union +from typing import Optional, Union import torch import torch.distributions @@ -25,7 +25,6 @@ from ...generation.utils import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import is_torch_flex_attn_available, logging from .configuration_blt import ( BLTConfig, @@ -43,9 +42,10 @@ from ..mllama.modeling_mllama import ( MllamaPreTrainedModel, MllamaRotaryEmbedding, + MllamaTextCrossAttention, MllamaTextMLP, MllamaTextRMSNorm, - eager_attention_forward, + MllamaTextSelfAttention, ) @@ -347,8 +347,8 @@ def __init__(self, config, layer_idx: int): self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -418,91 +418,13 @@ def forward( return outputs -class BLTSelfAttention(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.config = config - self.num_heads = config.num_attention_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.rope_theta = config.rope_theta - self.layer_idx = layer_idx +class BLTSelfAttention(MllamaTextSelfAttention): + """BLT variant of MllamaTextSelfAttention. Inherits all logic directly.""" + def __init__(self, config: BLTConfig, layer_idx: int): + super().__init__(config, layer_idx) self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = ( - {"sin": sin, "cos": cos, "cache_position": cache_position} - if position_embeddings is not None - else {"cache_position": cache_position} - ) - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) @@ -674,7 +596,7 @@ def __init__(self, config: BLTLocalDecoderConfig): bias=False, ) - self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 @@ -772,110 +694,14 @@ def forward( return logits, all_hidden_states, all_attentions -class BLTCrossAttention(nn.Module): +class BLTCrossAttention(MllamaTextCrossAttention): """Cross-attention module for BLT, following transformers style""" def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): super().__init__() - self.config = config - self.layer_idx = layer_idx - # Use provided hidden_size or fallback to encoder dimension - self.hidden_size = hidden_size or config.encoder_config.hidden_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.dropout = config.dropout - self.is_causal = False - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first - query_states = self.q_proj(query_states) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - if past_key_value is not None: - # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states - # we still update the cross key states, past_cross_states, new_cross_states. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position is not None and cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - if cross_attention_states is None: - raise ValueError( - "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" - ) - - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if full_text_row_masked_out_mask is not None: - attn_output = full_text_row_masked_out_mask[:, 0] * attn_output - - attn_output = attn_output + hidden_states - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + self.q_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) class BLTGlobalTransformer(nn.Module): @@ -1165,7 +991,7 @@ def __init__(self, config: BLTPatcherConfig): self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.lm_head = nn.Linear( self.config.hidden_size, From 82bff4e9df3690d82e3a6306a88407549c3c7215 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 15:35:46 +0000 Subject: [PATCH 080/139] more modular reduction --- src/transformers/models/blt/modeling_blt.py | 210 +++++++++--------- src/transformers/models/blt/modular_blt.py | 233 ++------------------ 2 files changed, 132 insertions(+), 311 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 6caf3fd6625e..404738cefeb9 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -30,7 +30,7 @@ from ...cache_utils import Cache from ...generation.utils import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -143,7 +143,7 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_flash_attn_2 = False def _init_weights(self, module): std = self.config.initializer_range @@ -288,28 +288,33 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class BLTTransformerLayer(GradientCheckpointingLayer): +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer +class BLTTransformerLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.layer_idx = layer_idx self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config) self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + def forward( self, hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -317,18 +322,16 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor`, *optional*): - Position indices of tokens in the sequence for RoPE computation. - past_key_value (`Cache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): @@ -336,9 +339,11 @@ def forward( into the model """ residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + # Self Attention + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -351,6 +356,7 @@ def forward( ) hidden_states = residual + hidden_states + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) @@ -361,9 +367,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1203,14 +1206,13 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # Initialize collections to None - we will ONLY collect from decoder - all_hidden_states = None - all_attentions = None - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # Explicit input validation (not XOR) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") if input_ids is not None: batch_size, sequence_length = input_ids.shape @@ -1306,25 +1308,20 @@ def forward( ) # Only use decoder outputs (which match the expected num_hidden_layers) - if output_hidden_states and decoder_hidden_states_all is not None: - all_hidden_states = decoder_hidden_states_all - else: - all_hidden_states = None - - if output_attentions and decoder_attentions_all is not None: - all_attentions = decoder_attentions_all - else: - all_attentions = None + all_hidden_states = ( + decoder_hidden_states_all if output_hidden_states and decoder_hidden_states_all is not None else None + ) + all_attentions = decoder_attentions_all if output_attentions and decoder_attentions_all is not None else None if not return_dict: - output = (output,) + outputs = (output,) if past_key_values is not None: - output = output + (past_key_values,) + outputs = outputs + (past_key_values,) if all_hidden_states is not None: - output = output + (all_hidden_states,) + outputs = outputs + (all_hidden_states,) if all_attentions is not None: - output = output + (all_attentions,) - return output + outputs = outputs + (all_attentions,) + return outputs return BaseModelOutputWithPast( last_hidden_state=output, @@ -1359,18 +1356,13 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config: BLTPatcherConfig): super().__init__(config) - self.rotary_emb = BLTRotaryEmbedding(config=self.config) - self.layers = nn.ModuleList() - for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, @@ -1507,18 +1499,26 @@ def patch_lengths_from_entropies( return patch_lengths +@auto_docstring( + custom_intro=""" + The BLT Text Model with a language modeling head on top. + """ +) class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] config_class = BLTConfig + _supports_static_cache = True # only the LLM without cross attn can do compile base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] def __init__(self, config): - super().__init__(config) - self.model = BLTModel(config) + super().__init__(config.get_text_config()) + self.text_config = config.get_text_config() self.vocab_size = config.vocab_size + self.model = BLTModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() def get_input_embeddings(self): @@ -1539,12 +1539,16 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1552,28 +1556,67 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[CausalLMOutputWithPast, tuple]: + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. + cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): + Cross-attention mask to control the interaction between text tokens and image tiles. + This 4D tensor defines which image tiles each text token should attend to. + + For each text token (in seq_length): + - 1 indicates the token **should attend** to the corresponding image tile + - 0 indicates the token **should not attend** to the corresponding image tile + full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*): + A tuple containing two tensors that mask out rows in the cross-attention mechanism: + - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. + A value of 0 indicates that the corresponding text token's entire row in the cross-attention + matrix should be masked out (all image tokens ignored). + - The second tensor has the same shape and is used internally to apply the masking during + the forward pass of cross-attention layers. + This mask is derived from the cross_attention_mask and is used to handle cases where a text token + should not attend to any image token. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BLTForCausalLM + + >>> model = BLTForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` """ - Args: - input_ids (torch.LongTensor): Input token ids. - attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. - labels (torch.LongTensor, optional): Labels for language modeling loss. - Returns: - Union[CausalLMOutputWithPast, tuple]: Standard transformers output. - """ - # Set defaults from config when parameters are None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # Route only input_ids to BLTModel (as tokens) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( - input_ids, + input_ids=input_ids, + cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -1584,55 +1627,24 @@ def forward( **kwargs, ) - if isinstance(outputs, dict): - sequence_output = outputs["last_hidden_state"] - past_key_values = outputs.get("past_key_values") - hidden_states = outputs.get("hidden_states") - attentions = outputs.get("attentions") - elif isinstance(outputs, tuple): - sequence_output = outputs[0] - # Handle tuple format: (output, past_key_values?, hidden_states?, attentions?) - idx = 1 - past_key_values = None - hidden_states = None - attentions = None - - if len(outputs) > idx and use_cache: - past_key_values = outputs[idx] - idx += 1 - - if len(outputs) > idx and output_hidden_states: - hidden_states = outputs[idx] - idx += 1 - - if len(outputs) > idx and output_attentions: - attentions = outputs[idx] - idx += 1 - else: - sequence_output = outputs - past_key_values = None - hidden_states = None - attentions = None + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() - logits = self.lm_head(sequence_output) loss = None if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + if not return_dict: - output = (logits,) - if loss is not None: - output = (loss,) + output - return output + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=past_key_values, - hidden_states=hidden_states, - attentions=attentions, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 45903552a256..cb4c6858cdf3 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -22,9 +22,7 @@ import torch.nn.functional as F from ...cache_utils import Cache -from ...generation.utils import GenerationMixin -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast from ...utils import is_torch_flex_attn_available, logging from .configuration_blt import ( BLTConfig, @@ -40,8 +38,10 @@ from ..mllama.modeling_mllama import ( + MllamaForCausalLM, MllamaPreTrainedModel, MllamaRotaryEmbedding, + MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextMLP, MllamaTextRMSNorm, @@ -304,9 +304,6 @@ def __init__(self, config: BLTConfig, device=None): ) -# INHERITED BUT CUSTOMIZED COMPONENTS - - class BLTPreTrainedModel(MllamaPreTrainedModel): """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" @@ -315,7 +312,7 @@ class BLTPreTrainedModel(MllamaPreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_flash_attn_2 = False _supports_sdpa = True _supports_cache_class = False @@ -339,84 +336,15 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -class BLTTransformerLayer(GradientCheckpointingLayer): +class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): def __init__(self, config, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config) self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - position_ids (`torch.LongTensor`, *optional*): - Position indices of tokens in the sequence for RoPE computation. - past_key_value (`Cache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - class BLTSelfAttention(MllamaTextSelfAttention): """BLT variant of MllamaTextSelfAttention. Inherits all logic directly.""" @@ -825,14 +753,13 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Initialize collections to None - we will ONLY collect from decoder - all_hidden_states = None - all_attentions = None + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # Explicit input validation (not XOR) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") if input_ids is not None: batch_size, sequence_length = input_ids.shape @@ -928,25 +855,20 @@ def forward( ) # Only use decoder outputs (which match the expected num_hidden_layers) - if output_hidden_states and decoder_hidden_states_all is not None: - all_hidden_states = decoder_hidden_states_all - else: - all_hidden_states = None - - if output_attentions and decoder_attentions_all is not None: - all_attentions = decoder_attentions_all - else: - all_attentions = None + all_hidden_states = ( + decoder_hidden_states_all if output_hidden_states and decoder_hidden_states_all is not None else None + ) + all_attentions = decoder_attentions_all if output_attentions and decoder_attentions_all is not None else None if not return_dict: - output = (output,) + outputs = (output,) if past_key_values is not None: - output = output + (past_key_values,) + outputs = outputs + (past_key_values,) if all_hidden_states is not None: - output = output + (all_hidden_states,) + outputs = outputs + (all_hidden_states,) if all_attentions is not None: - output = output + (all_attentions,) - return output + outputs = outputs + (all_attentions,) + return outputs return BaseModelOutputWithPast( last_hidden_state=output, @@ -981,18 +903,13 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BLTPatcher(BLTPreTrainedModel): def __init__(self, config: BLTPatcherConfig): super().__init__(config) - self.rotary_emb = BLTRotaryEmbedding(config=self.config) - self.layers = nn.ModuleList() - for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, @@ -1129,7 +1046,7 @@ def patch_lengths_from_entropies( return patch_lengths -class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): +class BLTForCausalLM(MllamaForCausalLM): _tied_weights_keys = ["lm_head.weight"] config_class = BLTConfig base_model_prefix = "model" @@ -1149,114 +1066,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.local_encoder.embed_tokens = value - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[CausalLMOutputWithPast, tuple]: - """ - Args: - input_ids (torch.LongTensor): Input token ids. - attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. - labels (torch.LongTensor, optional): Labels for language modeling loss. - Returns: - Union[CausalLMOutputWithPast, tuple]: Standard transformers output. - """ - # Set defaults from config when parameters are None - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # Route only input_ids to BLTModel (as tokens) - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - if isinstance(outputs, dict): - sequence_output = outputs["last_hidden_state"] - past_key_values = outputs.get("past_key_values") - hidden_states = outputs.get("hidden_states") - attentions = outputs.get("attentions") - elif isinstance(outputs, tuple): - sequence_output = outputs[0] - # Handle tuple format: (output, past_key_values?, hidden_states?, attentions?) - idx = 1 - past_key_values = None - hidden_states = None - attentions = None - - if len(outputs) > idx and use_cache: - past_key_values = outputs[idx] - idx += 1 - - if len(outputs) > idx and output_hidden_states: - hidden_states = outputs[idx] - idx += 1 - - if len(outputs) > idx and output_attentions: - attentions = outputs[idx] - idx += 1 - else: - sequence_output = outputs - past_key_values = None - hidden_states = None - attentions = None - - logits = self.lm_head(sequence_output) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = torch.nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: - output = (logits,) - if loss is not None: - output = (loss,) + output - return output - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=past_key_values, - hidden_states=hidden_states, - attentions=attentions, - ) - __all__ = [ "BLTPreTrainedModel", From 6f174745210cc2c1b461e95328fbd1d47a912c92 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 15:45:35 +0000 Subject: [PATCH 081/139] ruff --- src/demo_hf.py | 47 ------------------- src/transformers/models/__init__.py | 2 +- .../models/auto/tokenization_auto.py | 2 +- 3 files changed, 2 insertions(+), 49 deletions(-) delete mode 100644 src/demo_hf.py diff --git a/src/demo_hf.py b/src/demo_hf.py deleted file mode 100644 index 97ad40a296bf..000000000000 --- a/src/demo_hf.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -import os - -import torch - -from transformers.models.blt.modeling_blt import BLTForCausalLM - -from transformers.models.blt.tokenization_blt import BLTTokenizer - -logger = logging.getLogger() - -os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" - -import gc -gc.collect() -torch.cuda.empty_cache() - - -def main(prompt: str = "my name is", model_name: str = "blt-1b"): - device = "cuda" - - model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b" - ).to(device) - - tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) - - input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device) - - with torch.no_grad(): - output_ids = model.generate( - input_ids, - max_new_tokens=200, - do_sample=False, - temperature=1.0 - ) - - generated_ids = output_ids[0][len(input_ids[0]):] - output_text = tokenizer.decode(generated_ids.tolist()) - - print(f'Prompt: "{prompt}"') - print(f'Completion: "{output_text}"') - print('here') - - -if __name__ == "__main__": - main() diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5cde03fc4b5a..e6ef726f7565 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -389,4 +389,4 @@ import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 3ddbfacf0311..b79a5d5c8df3 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -1213,4 +1213,4 @@ def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok) -__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"] \ No newline at end of file +__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"] From c191396e7f995efc933cee7970f0ee165c765c1a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 16:03:33 +0000 Subject: [PATCH 082/139] modular fix --- src/transformers/models/blt/modeling_blt.py | 12 +++++++----- src/transformers/models/blt/modular_blt.py | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 404738cefeb9..3144317f189a 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache -from ...generation.utils import GenerationMixin +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -131,19 +131,21 @@ def forward(self, x, position_ids): class BLTPreTrainedModel(PreTrainedModel): """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" - config_class = BLTConfig + config: BLTConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _supports_cache_class = False + _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True - _supports_quantized_cache = True _supports_flex_attn = True _supports_attention_backend = True + + config_class = BLTConfig _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = False + _supports_cache_class = False def _init_weights(self, module): std = self.config.initializer_range @@ -1505,7 +1507,7 @@ def patch_lengths_from_entropies( """ ) class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): - config_class = BLTConfig + config: BLTConfig _supports_static_cache = True # only the LLM without cross attn can do compile base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index cb4c6858cdf3..1d920f6a0e0a 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -1047,10 +1047,10 @@ def patch_lengths_from_entropies( class BLTForCausalLM(MllamaForCausalLM): - _tied_weights_keys = ["lm_head.weight"] - config_class = BLTConfig - base_model_prefix = "model" + config: BLTConfig supports_gradient_checkpointing = True + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] def __init__(self, config): From 50c0353be13746cf118764ca8262ce2427634a1a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 16 Jul 2025 16:11:22 +0000 Subject: [PATCH 083/139] fix styling --- src/transformers/models/auto/configuration_auto.py | 4 ++-- src/transformers/models/auto/modeling_auto.py | 4 ++-- utils/check_repo.py | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3ddd8f367c84..47d4673a3b8e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -65,6 +65,7 @@ ("blip-2", "Blip2Config"), ("blip_2_qformer", "Blip2QFormerConfig"), ("bloom", "BloomConfig"), + ("blt", "BLTConfig"), ("bridgetower", "BridgeTowerConfig"), ("bros", "BrosConfig"), ("camembert", "CamembertConfig"), @@ -226,7 +227,6 @@ ("lightglue", "LightGlueConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), - ("blt", "BLTConfig"), ("llama4", "Llama4Config"), ("llama4_text", "Llama4TextConfig"), ("llava", "LlavaConfig"), @@ -491,6 +491,7 @@ ("blip-2", "BLIP-2"), ("blip_2_qformer", "BLIP-2 QFormer"), ("bloom", "BLOOM"), + ("blt", "BLT"), ("bort", "BORT"), ("bridgetower", "BridgeTower"), ("bros", "BROS"), @@ -664,7 +665,6 @@ ("lightglue", "LightGlue"), ("lilt", "LiLT"), ("llama", "LLaMA"), - ("blt", "BLT"), ("llama2", "Llama2"), ("llama3", "Llama3"), ("llama4", "Llama4"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 961ea5ddd91c..d758acd687bd 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -73,6 +73,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blip_2_qformer", "Blip2QFormerModel"), ("bloom", "BloomModel"), ("blt", "BLTModel"), + ("blt", "BLTModel"), ("bridgetower", "BridgeTowerModel"), ("bros", "BrosModel"), ("camembert", "CamembertModel"), @@ -227,7 +228,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("lightglue", "LightGlueForKeypointMatching"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), - ("blt", "BLTModel"), ("llama4", "Llama4ForConditionalGeneration"), ("llama4_text", "Llama4TextModel"), ("llava", "LlavaModel"), @@ -635,6 +635,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blenderbot", "BlenderbotForCausalLM"), ("blenderbot-small", "BlenderbotSmallForCausalLM"), ("bloom", "BloomForCausalLM"), + ("blt", "BLTForCausalLM"), ("camembert", "CamembertForCausalLM"), ("code_llama", "LlamaForCausalLM"), ("codegen", "CodeGenForCausalLM"), @@ -690,7 +691,6 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("jetmoe", "JetMoeForCausalLM"), ("lfm2", "Lfm2ForCausalLM"), ("llama", "LlamaForCausalLM"), - ("blt", "BLTForCausalLM"), ("llama4", "Llama4ForCausalLM"), ("llama4_text", "Llama4ForCausalLM"), ("longcat_flash", "LongcatFlashForCausalLM"), diff --git a/utils/check_repo.py b/utils/check_repo.py index e932e5bfc24c..77a2a0a2e258 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -180,6 +180,7 @@ "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. + "BLTPatcher", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "Florence2VisionBackbone", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration. ] ) From 36a95538bee2f2c4f6242660392eaa15e2925ca1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 17 Jul 2025 09:44:21 +0000 Subject: [PATCH 084/139] return 2 --- src/transformers/models/blt/modeling_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 3144317f189a..3d9c81602717 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -630,7 +630,7 @@ def forward( ) layer_idx = idx if self.config.cross_attn_all_layers else 0 - cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=cross_mask, From 132830e0389814049d85aa9c47d338a132c155b3 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 17 Jul 2025 09:51:11 +0000 Subject: [PATCH 085/139] return 2 --- src/transformers/models/blt/modeling_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 3d9c81602717..e1ba62a2c1ff 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -750,7 +750,7 @@ def forward( if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states - cross_attention_output, _, _ = self.cross_attn_layers[i]( + cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=cross_mask, From 9f3a3b42ba0dae50b46d84c8f28e1c1625b6faaf Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 17 Jul 2025 14:40:33 +0000 Subject: [PATCH 086/139] fix some tests --- src/transformers/models/blt/modeling_blt.py | 1 - src/transformers/models/blt/modular_blt.py | 5 ++--- utils/check_docstrings.py | 3 +++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index e1ba62a2c1ff..ec5d129a1aea 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -992,7 +992,6 @@ def byte_group_hash_function( hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values = hashes % max_hash - hash_values.requires_grad = False return hash_values diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 1d920f6a0e0a..a251d3c32d09 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -120,7 +120,6 @@ def byte_group_hash_function( hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values = hashes % max_hash - hash_values.requires_grad = False return hash_values @@ -459,7 +458,7 @@ def forward( ) layer_idx = idx if self.config.cross_attn_all_layers else 0 - cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=cross_mask, @@ -579,7 +578,7 @@ def forward( if i == 0 or self.config.cross_attn_all_layers: # Use cross attention to extract info from patch_embeds into hidden_states - cross_attention_output, _, _ = self.cross_attn_layers[i]( + cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=cross_mask, diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 9eeda74afa48..81cfe7fd7e27 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -128,6 +128,9 @@ "BlipVisionConfig", "BloomConfig", "BloomTokenizerFast", + "BLTConfig", + "BLTPatcherConfig", + "BLTTokenizer", "BridgeTowerTextConfig", "BridgeTowerVisionConfig", "BrosModel", From 809530348ba735c40043e27c4050361aacb43b3d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 18 Jul 2025 16:43:29 +0000 Subject: [PATCH 087/139] fix bltcrossattention after modular break --- src/demo_hf.py | 47 +++++++++++++++++++++ src/transformers/models/blt/modeling_blt.py | 15 ++++--- tests/models/blt/test_modeling_blt.py | 38 ++++++++--------- 3 files changed, 75 insertions(+), 25 deletions(-) create mode 100644 src/demo_hf.py diff --git a/src/demo_hf.py b/src/demo_hf.py new file mode 100644 index 000000000000..3074b0676c41 --- /dev/null +++ b/src/demo_hf.py @@ -0,0 +1,47 @@ +import logging +import os + +import torch + +from transformers.models.blt.modeling_blt import BLTForCausalLM + +from transformers.models.blt.tokenization_blt import BLTTokenizer + +from transformers import AutoTokenizer + +logger = logging.getLogger() + +os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" + +import gc +gc.collect() +torch.cuda.empty_cache() + + +def main(prompt: str = "my name is", model_name: str = ""): + + model = BLTForCausalLM.from_pretrained( + model_name, + device_map="auto", + # attn_implementation="eager" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False, use_cache=False) + + output_text = tokenizer.decode(generated_ids[0]) + + print(f'Model: "{model_name}"') + # print(f'Prompt: "{prompt}"') + print(f'Completion: "{output_text}"') + + +if __name__ == "__main__": + # SNAPSHOT_PATH = os.path.expanduser("~/.cache/huggingface/hub/models--itazap--blt-1b/snapshots/bb8c23be2c2f065f0ee315ec2066ac1d3c78722a") + # main(model_name=SNAPSHOT_PATH) + main(model_name="itazap/blt-1b-testing") + diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index ec5d129a1aea..72ecb1d22c17 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -812,8 +812,9 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.q_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + # needs to stay hidden_size, NOT head_dim + self.q_norm = BLTRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BLTRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.is_causal = False def forward( @@ -829,17 +830,18 @@ def forward( ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - key_states = self.k_norm(key_states) if past_key_value is not None: # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! @@ -880,6 +882,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states if not output_attentions: attn_weights = None diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index aade85b7189f..d5eb7b4c94d0 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -76,7 +76,7 @@ def __init__( self.vocab_size = 32 self.rope_theta = 500000.0 self.rope_scaling = {"rope_type": "default"} - self.norm_eps = 1e-5 + self.rms_norm_eps = 1e-5 self.dropout = 0.0 self.encoder_hash_byte_group_size = [2, 3] self.encoder_hash_byte_group_vocab = 64 @@ -92,7 +92,7 @@ def __init__( "rope_theta": self.rope_theta, "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, - "norm_eps": self.norm_eps, + "rms_norm_eps": self.rms_norm_eps, "dropout": self.dropout, } @@ -106,7 +106,7 @@ def __init__( "rope_theta": self.rope_theta, "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, - "norm_eps": self.norm_eps, + "rms_norm_eps": self.rms_norm_eps, "dropout": self.dropout, } @@ -122,7 +122,7 @@ def __init__( "rope_theta": self.rope_theta, "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, - "norm_eps": self.norm_eps, + "rms_norm_eps": self.rms_norm_eps, "dropout": self.dropout, } @@ -136,7 +136,7 @@ def __init__( "rope_theta": self.rope_theta, "rope_scaling": self.rope_scaling, "hidden_act": self.hidden_act, - "norm_eps": self.norm_eps, + "rms_norm_eps": self.rms_norm_eps, "dropout": self.dropout, } @@ -331,15 +331,15 @@ def test_blt(self): prompt = "my name is" model = BLTForCausalLM.from_pretrained( - "itazap/blt-1b", + "itazap/blt-1b-testing", device_map="auto", ) - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -418,7 +418,7 @@ def test_model_logits(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto") + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto") with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -435,13 +435,13 @@ def test_model_bf16(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -523,7 +523,7 @@ def test_model_logits_bf16(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -541,13 +541,13 @@ def test_model_eager(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", attn_implementation="eager") + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="eager") - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -562,15 +562,15 @@ def test_model_bf16_static_cache(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) model.generation_config.cache_implementation = "static" - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) From 562c03ada894aa397fe3dd37e9b308bbf996c7af Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 21 Jul 2025 14:01:52 +0000 Subject: [PATCH 088/139] some fixes / feedback --- .../models/blt/configuration_blt.py | 6 - src/transformers/models/blt/modeling_blt.py | 84 +++++++------ src/transformers/models/blt/modular_blt.py | 111 +++++++++++++----- tests/models/blt/test_modeling_blt.py | 28 ++++- 4 files changed, 152 insertions(+), 77 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index d349c2bb4df1..102aa8bb1792 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -182,8 +182,6 @@ class BLTPatcherConfig(PretrainedConfig): Make feedforward dimension multiple of this for the entropy model. rope_theta (`float`, *optional*, defaults to 10000.0): RoPE theta parameter for the entropy model. - _attn_implementation (`str`, *optional*, defaults to "sdpa"): - Attention implementation for the entropy model. attn_bias_type (`str`, *optional*, defaults to "causal"): Attention bias type for the entropy model. """ @@ -237,10 +235,6 @@ class BLTConfig(PretrainedConfig): Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. - _attn_implementation (`str`, *optional*, defaults to "sdpa"): - The attention implementation to use. Can be "eager", "sdpa", etc. This setting is propagated to all - sub-components (encoder, decoder, global transformer, patcher). - # Patching configuration patch_in_forward (`bool`, *optional*, defaults to False): Whether to perform patching during forward pass. diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 72ecb1d22c17..f172c8856ab5 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -5,7 +5,7 @@ # modular_blt.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# Copyright 2025 HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -419,8 +419,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] @@ -549,7 +549,7 @@ def __init__(self, config: BLTLocalEncoderConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.cross_attn_layers = torch.nn.ModuleList() + self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( @@ -604,7 +604,7 @@ def forward( layer.__call__, hidden_states, position_embeddings, - None, # attention_mask + attention_mask, None, # past_key_value False, # output_attentions False, # use_cache @@ -614,7 +614,7 @@ def forward( layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=None, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] @@ -697,7 +697,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.cross_attn_layers = torch.nn.ModuleList() + self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( @@ -766,7 +766,7 @@ def forward( layer.__call__, hidden_states, position_embeddings, - None, # attention_mask + attention_mask, None, # past_key_value False, # output_attentions False, # use_cache @@ -776,7 +776,7 @@ def forward( layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=None, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] @@ -939,7 +939,7 @@ def forward( layer.__call__, hidden_states, position_embeddings, - None, # attention_mask + attention_mask, None, # past_key_value False, # output_attentions False, # use_cache @@ -949,7 +949,7 @@ def forward( layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, - attention_mask=None, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] @@ -998,13 +998,6 @@ def byte_group_hash_function( return hash_values -def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): - """Initialize hash-based token embeddings for the BLT encoder.""" - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] - return nn.ModuleList(embeddings) - - def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, @@ -1166,11 +1159,13 @@ def __init__(self, config: BLTConfig): self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - self.encoder_hash_tok_embedding = init_hash_embeddings( - config, - local_encoder_dim=config.encoder_config.hidden_size, - encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, - ) + + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) + for _ in range(num_embeddings) + ] + self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) if self.config.patch_in_forward: self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() @@ -1179,7 +1174,6 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - # Initialize weights and apply final processing self.post_init() def forward( @@ -1264,7 +1258,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = attention_mask + causal_mask = self._update_causal_mask( + attention_mask, encoder_embeds, cache_position, past_key_values, output_attentions + ) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype @@ -1274,6 +1270,7 @@ def forward( input_ids=input_ids, input_embeds=encoder_embeds, patch_embeds=None, + attention_mask=causal_mask, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], @@ -1284,10 +1281,21 @@ def forward( ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, global_hidden_states_all, global_attentions_all = self.global_transformer( + global_cache_position = torch.arange( + 0, global_hidden_states.shape[1], device=global_hidden_states.device + ) + global_causal_mask = self._update_causal_mask( + None, # attention_mask + global_hidden_states, + global_cache_position, + None, # past_key_values + False # output_attentions + ) + global_hidden_states, _, _ = self.global_transformer( input_embeds=global_hidden_states, - output_attentions=False, # Don't collect global transformer attentions - output_hidden_states=False, # Don't collect global transformer hidden states + attention_mask=global_causal_mask, + output_attentions=False, + output_hidden_states=False, ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) @@ -1409,8 +1417,17 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) + cache_position = torch.arange(sequence_length, device=input_embeds.device) + causal_mask = self._update_causal_mask( + None, # attention_mask + input_embeds, + cache_position, + None, # past_key_values + False # output_attentions + ) + for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) @@ -1656,14 +1673,5 @@ def forward( "BLTPreTrainedModel", "BLTModel", "BLTPatcher", - "BLTLocalEncoder", - "BLTLocalDecoder", - "BLTGlobalTransformer", - "BLTTransformerLayer", "BLTForCausalLM", - "BLTMLP", - "BLTRMSNorm", - "BLTRotaryEmbedding", - "BLTSelfAttention", - "BLTCrossAttention", ] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index a251d3c32d09..ca9b037922b9 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# Copyright 2025 HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,8 +61,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] @@ -123,13 +123,6 @@ def byte_group_hash_function( return hash_values -def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): - """Initialize hash-based token embeddings for the BLT encoder.""" - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) - embeddings = [nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) for _ in range(num_embeddings)] - return nn.ModuleList(embeddings) - - def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, @@ -377,7 +370,7 @@ def __init__(self, config: BLTLocalEncoderConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.cross_attn_layers = torch.nn.ModuleList() + self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( @@ -525,7 +518,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.cross_attn_layers = torch.nn.ModuleList() + self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( @@ -630,6 +623,78 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.q_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + class BLTGlobalTransformer(nn.Module): def __init__(self, config: BLTGlobalTransformerConfig): @@ -708,11 +773,13 @@ def __init__(self, config: BLTConfig): self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - self.encoder_hash_tok_embedding = init_hash_embeddings( - config, - local_encoder_dim=config.encoder_config.hidden_size, - encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, - ) + + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) + for _ in range(num_embeddings) + ] + self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) if self.config.patch_in_forward: self.patcher = BLTPatcher(config.patcher_config) self.patcher.eval() @@ -721,7 +788,6 @@ def __init__(self, config: BLTConfig): else: self.patcher = None - # Initialize weights and apply final processing self.post_init() def forward( @@ -1070,14 +1136,5 @@ def set_input_embeddings(self, value): "BLTPreTrainedModel", "BLTModel", "BLTPatcher", - "BLTLocalEncoder", - "BLTLocalDecoder", - "BLTGlobalTransformer", - "BLTTransformerLayer", "BLTForCausalLM", - "BLTMLP", - "BLTRMSNorm", - "BLTRotaryEmbedding", - "BLTSelfAttention", - "BLTCrossAttention", ] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index d5eb7b4c94d0..2eda2b08d21c 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -324,7 +324,7 @@ def tearDown(self): @slow @require_read_token - def test_blt(self): + def test_model(self): NUM_TOKENS_TO_GENERATE = 200 EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" @@ -333,6 +333,7 @@ def test_blt(self): model = BLTForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", + attn_implementation="sdpa" ) tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") @@ -418,7 +419,7 @@ def test_model_logits(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto") + model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", attn_implementation="sdpa", device_map="auto") with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -435,7 +436,11 @@ def test_model_bf16(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b-testing", + device_map="auto", + attn_implementation="sdpa", + torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") @@ -523,7 +528,11 @@ def test_model_logits_bf16(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b-testing", + device_map="auto", + attn_implementation="sdpa", + torch_dtype=torch.bfloat16) with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -541,7 +550,10 @@ def test_model_eager(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="eager") + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b-testing", + device_map="auto", + attn_implementation="eager") tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") @@ -562,7 +574,11 @@ def test_model_bf16_static_cache(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", torch_dtype=torch.bfloat16) + model = BLTForCausalLM.from_pretrained( + "itazap/blt-1b-testing", + device_map="auto", + attn_implementation="sdpa", + torch_dtype=torch.bfloat16) model.generation_config.cache_implementation = "static" From fc1e7bfad08c1af5ca337c45d4ecd1d801ea0d6c Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 21 Jul 2025 14:51:48 +0000 Subject: [PATCH 089/139] try cache generate fix --- src/transformers/models/blt/modeling_blt.py | 55 +++++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index f172c8856ab5..c485a636f5d3 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -605,17 +605,20 @@ def forward( hidden_states, position_embeddings, attention_mask, - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position + past_key_values, + output_attentions, + use_cache, + cache_position, ) else: layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, + past_key_value=past_key_values, output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -738,9 +741,10 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - # Use sequence length from embeds (standard transformers pattern) - seq_len = embeds.shape[1] - position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + # Use passed position_ids if available, otherwise compute from sequence length + if position_ids is None: + seq_len = embeds.shape[1] + position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -767,17 +771,20 @@ def forward( hidden_states, position_embeddings, attention_mask, - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position + past_key_values, + output_attentions, + use_cache, + cache_position, ) else: layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, + past_key_value=past_key_values, output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -940,17 +947,20 @@ def forward( hidden_states, position_embeddings, attention_mask, - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position + past_key_values, + output_attentions, + use_cache, + cache_position, ) else: layer_outputs = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, + past_key_value=past_key_values, output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1271,6 +1281,10 @@ def forward( input_embeds=encoder_embeds, patch_embeds=None, attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], @@ -1284,16 +1298,21 @@ def forward( global_cache_position = torch.arange( 0, global_hidden_states.shape[1], device=global_hidden_states.device ) + global_position_ids = global_cache_position.unsqueeze(0) global_causal_mask = self._update_causal_mask( None, # attention_mask global_hidden_states, global_cache_position, - None, # past_key_values + None, # past_key_values - global transformer doesn't use cache False # output_attentions ) global_hidden_states, _, _ = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, + position_ids=global_position_ids, + past_key_values=None, # Global transformer doesn't use cache + use_cache=False, + cache_position=global_cache_position, output_attentions=False, output_hidden_states=False, ) @@ -1312,6 +1331,10 @@ def forward( embeds=encoder_hidden_states, patch_embeds=global_hidden_states, attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, From 8add244735c7cfbc3a8c199229bf4ba6f05045f5 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 21 Jul 2025 15:06:46 +0000 Subject: [PATCH 090/139] try cache generate fix --- src/transformers/models/blt/modeling_blt.py | 30 ++++++++++++--------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index c485a636f5d3..1b132295fa99 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -588,9 +588,10 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -741,10 +742,10 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - # Use passed position_ids if available, otherwise compute from sequence length if position_ids is None: - seq_len = embeds.shape[1] - position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -934,7 +935,10 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): @@ -1282,9 +1286,9 @@ def forward( patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, + past_key_values=None, + use_cache=False, + cache_position=None, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], @@ -1303,16 +1307,16 @@ def forward( None, # attention_mask global_hidden_states, global_cache_position, - None, # past_key_values - global transformer doesn't use cache + None, # past_key_values False # output_attentions ) global_hidden_states, _, _ = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, position_ids=global_position_ids, - past_key_values=None, # Global transformer doesn't use cache + past_key_values=None, use_cache=False, - cache_position=global_cache_position, + cache_position=None, output_attentions=False, output_hidden_states=False, ) From f198de796672b3377874e4ac3f95c79a374bbfed Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 21 Jul 2025 17:33:38 +0000 Subject: [PATCH 091/139] fix generate tests --- src/transformers/models/blt/modeling_blt.py | 32 +++++++++------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 1b132295fa99..6c56fd0fb3c2 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -588,10 +588,9 @@ def forward( hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - if position_ids is None: - position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -742,10 +741,10 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds + # Use passed position_ids if available, otherwise compute from sequence length if position_ids is None: - position_ids = ( - torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + seq_len = embeds.shape[1] + position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) @@ -935,10 +934,7 @@ def forward( hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - if position_ids is None: - position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) for i, layer in enumerate(self.layers): @@ -1286,9 +1282,9 @@ def forward( patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, - past_key_values=None, - use_cache=False, - cache_position=None, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, cross_mask=cross_attn_mask_enc, full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, num_patches=patch_lengths.shape[1], @@ -1307,16 +1303,16 @@ def forward( None, # attention_mask global_hidden_states, global_cache_position, - None, # past_key_values + None, # past_key_values - global transformer doesn't use cache False # output_attentions ) global_hidden_states, _, _ = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, position_ids=global_position_ids, - past_key_values=None, + past_key_values=None, # Global transformer doesn't use cache use_cache=False, - cache_position=None, + cache_position=global_cache_position, output_attentions=False, output_hidden_states=False, ) @@ -1701,4 +1697,4 @@ def forward( "BLTModel", "BLTPatcher", "BLTForCausalLM", -] +] \ No newline at end of file From ab4d2cae47a3e99e0a0b2f6263ab65f940147e6a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 22 Jul 2025 09:56:36 +0000 Subject: [PATCH 092/139] attn_impl workaround --- src/transformers/models/blt/modeling_blt.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 6c56fd0fb3c2..f9f99fbb7d10 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -166,6 +166,8 @@ def _init_weights(self, module): elif isinstance(module, nn.RMSNorm): module.weight.data.fill_(1.0) + + def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -1184,6 +1186,12 @@ def __init__(self, config: BLTConfig): else: self.patcher = None + # Sync attention implementation from main config to sub-configs + for subconfig_name in ["encoder_config", "decoder_config", "global_config", "patcher_config"]: + subconfig = getattr(self.config, subconfig_name) + if subconfig is not None: + subconfig._attn_implementation = self.config._attn_implementation + self.post_init() def forward( From a00ce1de26ab3470da8af77c2d534b2df3b18cff Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 22 Jul 2025 15:25:59 +0000 Subject: [PATCH 093/139] refactoring to use recent TransformersKwargs changes --- src/transformers/models/blt/modeling_blt.py | 825 +++++++------------- 1 file changed, 275 insertions(+), 550 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index f9f99fbb7d10..37d1168ff7fc 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -36,6 +36,8 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils.generic import check_model_inputs + from .configuration_blt import ( BLTConfig, BLTGlobalTransformerConfig, @@ -68,7 +70,7 @@ def __init__(self, config): # Ignore copy self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x): + def forward(self, x, **kwargs: Unpack[TransformersKwargs]): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -82,7 +84,7 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states): + def forward(self, hidden_states, **kwargs: Unpack[TransformersKwargs]): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -113,7 +115,7 @@ def __init__(self, config: BLTConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): + def forward(self, x, position_ids, **kwargs: Unpack[TransformersKwargs]): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -127,251 +129,42 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -@auto_docstring -class BLTPreTrainedModel(PreTrainedModel): - """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" - - config: BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - - _supports_static_cache = False # static cache cannot have different shapes for each layer - _supports_sdpa = True - _supports_flash_attn = True - _supports_flex_attn = True - _supports_attention_backend = True - - config_class = BLTConfig - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False - _supports_cache_class = False - - def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, BLTRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.RMSNorm): - module.weight.data.fill_(1.0) - - - - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer class BLTTransformerLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BLTMLP(config) self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, + position_embeddings=position_embeddings, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states - - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs + return hidden_states def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -460,7 +253,6 @@ def __init__(self, config: BLTConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -470,16 +262,16 @@ def __init__(self, config: BLTConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, + attention_mask: Optional[torch.Tensor], + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], ): + # Ensure hidden_states is always 3D (batch, seq_len, hidden_dim) + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -487,26 +279,15 @@ def forward( query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - + attn_impl = getattr(self.config, "_attn_implementation", None) or "eager" attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - + if attn_impl != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] attn_output, attn_weights = attention_interface( self, query_states, @@ -517,13 +298,8 @@ def forward( scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights @@ -533,24 +309,17 @@ def forward( class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() - self.config = config - self.gradient_checkpointing = False - self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.rotary_emb = BLTRotaryEmbedding(config=config) - self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size, out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): @@ -566,88 +335,51 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - """ """ - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - batch_size, _, _ = input_embeds.shape - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - - position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - for idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + **kwargs, + ) if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) - layer_idx = idx if self.config.cross_attn_all_layers else 0 + # Remove cross_attention_states from kwargs if present to avoid multiple values error + kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=cross_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, + **kwargs, ) patch_embeds = patch_embeds + cross_attention_output - encoder_cross_states = patch_embeds - return hidden_states, encoder_cross_states, all_hidden_states, all_attentions + return hidden_states, encoder_cross_states def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ @@ -682,26 +414,18 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTLocalDecoderConfig): super().__init__() - - # Extract config values to instance attributes self.config = config - self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove - self.gradient_checkpointing = False - + self.cross_attn_decoder = True self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.rotary_emb = BLTRotaryEmbedding(config=config) - self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): @@ -717,89 +441,48 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - batch_size, _, _ = embeds.shape - hidden_states = embeds - patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) - if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - - # Use passed position_ids if available, otherwise compute from sequence length if position_ids is None: - seq_len = embeds.shape[1] - position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - if i == 0 or self.config.cross_attn_all_layers: - # Use cross attention to extract info from patch_embeds into hidden_states + # Remove cross_attention_states from kwargs if present to avoid multiple values error + kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=cross_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, + **kwargs, ) hidden_states = hidden_states + cross_attention_output - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add final hidden state after all layers - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + **kwargs, + ) logits = self.norm(hidden_states) - # logits = self.lm_head(logits) - return logits, all_hidden_states, all_attentions + return logits class BLTCrossAttention(nn.Module): @@ -832,32 +515,25 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ): bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + cross_attention_states = self.k_norm(cross_attention_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - elif cache_position[0] != 0: + elif cache_position is not None and cache_position[0] != 0: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], past_key_value.value_cache[self.layer_idx], @@ -866,18 +542,10 @@ def forward( raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - + attn_impl = getattr(self.config, "_attn_implementation", None) or "eager" attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - + if attn_impl != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] attn_output, attn_weights = attention_interface( self, query_states, @@ -888,28 +556,19 @@ def forward( scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) attn_output = attn_output + hidden_states - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights class BLTGlobalTransformer(nn.Module): def __init__(self, config: BLTGlobalTransformerConfig): super().__init__() - self.config = config - self.gradient_checkpointing = False - self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): self.layers.append(BLTTransformerLayer(config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=config) def forward( @@ -918,58 +577,27 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - batch_size, seq_len, _ = input_embeds.shape - hidden_states = input_embeds - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - - position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - return hidden_states, all_hidden_states, all_attentions + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + return hidden_states def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): @@ -1164,6 +792,176 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded + +@auto_docstring +class BLTPreTrainedModel(PreTrainedModel): + """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" + + config: BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + + _supports_static_cache = False # static cache cannot have different shapes for each layer + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + + config_class = BLTConfig + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False + _supports_cache_class = False + + _can_record_outputs = { + "hidden_states": BLTTransformerLayer, + "attentions": BLTSelfAttention, + } + + def _init_weights(self, module): + std = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, BLTRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.RMSNorm): + module.weight.data.fill_(1.0) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Top-level model classes class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) @@ -1171,7 +969,6 @@ def __init__(self, config: BLTConfig): self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) - num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) embeddings = [ nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) @@ -1185,15 +982,9 @@ def __init__(self, config: BLTConfig): param.requires_grad = False else: self.patcher = None - - # Sync attention implementation from main config to sub-configs - for subconfig_name in ["encoder_config", "decoder_config", "global_config", "patcher_config"]: - subconfig = getattr(self.config, subconfig_name) - if subconfig is not None: - subconfig._attn_implementation = self.config._attn_implementation - self.post_init() + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1207,34 +998,21 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[BaseModelOutputWithPast, tuple]: - """ - Args: - input_ids (torch.LongTensor, optional): Input token ids. - patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. - attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. - Returns: - Union[BaseModelOutputWithPast, tuple]: Model outputs. - """ - # Set defaults from config when parameters are None + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Explicit input validation (not XOR) if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is None and inputs_embeds is None: raise ValueError("You have to specify either input_ids or inputs_embeds") - if input_ids is not None: batch_size, sequence_length = input_ids.shape else: batch_size, sequence_length, _ = inputs_embeds.shape - # Handle patching if patch_lengths is None: if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: @@ -1266,65 +1044,51 @@ def forward( self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) - if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( attention_mask, encoder_embeds, cache_position, past_key_values, output_attentions ) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = ( - self.local_encoder( - input_ids=input_ids, - input_embeds=encoder_embeds, - patch_embeds=None, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - output_attentions=False, # Don't collect encoder attentions - output_hidden_states=False, # Don't collect encoder hidden states - ) + # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error + kwargs.pop("full_text_row_masked_out_mask", None) + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=input_ids, + input_embeds=encoder_embeds, + patch_embeds=None, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=None, + cache_position=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + **kwargs, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_cache_position = torch.arange( 0, global_hidden_states.shape[1], device=global_hidden_states.device ) global_position_ids = global_cache_position.unsqueeze(0) global_causal_mask = self._update_causal_mask( - None, # attention_mask - global_hidden_states, - global_cache_position, - None, # past_key_values - global transformer doesn't use cache - False # output_attentions + None, global_hidden_states, global_cache_position, None, False ) - global_hidden_states, _, _ = self.global_transformer( + global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, position_ids=global_position_ids, - past_key_values=None, # Global transformer doesn't use cache - use_cache=False, - cache_position=global_cache_position, - output_attentions=False, - output_hidden_states=False, + past_key_values=None, + cache_position=None, + **kwargs, ) - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, @@ -1334,55 +1098,31 @@ def forward( self.config.cross_attn_k, encoder_embeds.dtype, ) - output, decoder_hidden_states_all, decoder_attentions_all = self.local_decoder( + output = self.local_decoder( input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, - use_cache=use_cache, cache_position=cache_position, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, - output_attentions=output_attentions, # Only collect decoder attentions - output_hidden_states=output_hidden_states, # Only collect decoder hidden states - ) - - # Only use decoder outputs (which match the expected num_hidden_layers) - all_hidden_states = ( - decoder_hidden_states_all if output_hidden_states and decoder_hidden_states_all is not None else None + **kwargs, ) - all_attentions = decoder_attentions_all if output_attentions and decoder_attentions_all is not None else None - - if not return_dict: - outputs = (output,) - if past_key_values is not None: - outputs = outputs + (past_key_values,) - if all_hidden_states is not None: - outputs = outputs + (all_hidden_states,) - if all_attentions is not None: - outputs = outputs + (all_attentions,) - return outputs - return BaseModelOutputWithPast( last_hidden_state=output, past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_attentions, ) def get_input_embeddings(self): - """Returns the model's input embeddings.""" return self.local_encoder.embed_tokens def set_input_embeddings(self, value): - """Sets the model's input embeddings.""" self.local_encoder.embed_tokens = value def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """Convert patch lengths to patch IDs for each token position.""" batch_size = patch_lengths.shape[0] patch_starts = torch.cat( [ @@ -1391,7 +1131,6 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> ], dim=-1, ) - token_positions = torch.arange(seq_len, device=patch_lengths.device) return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 @@ -1403,7 +1142,6 @@ def __init__(self, config: BLTPatcherConfig): self.layers = nn.ModuleList() for layer_idx in range(self.config.num_hidden_layers): self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.lm_head = nn.Linear( @@ -1420,14 +1158,13 @@ def forward( max_patch_length: Optional[int] = None, patching_batch_size: int = 1, device: Optional[str] = None, + **kwargs: Unpack[TransformersKwargs], ): - # Handle chunked processing for entropy calculation entropies = [] predictions = [] max_length = self.config.max_position_embeddings batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) - for split in splits: pad_size = (max_length - (split.numel() % max_length)) % max_length pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) @@ -1435,17 +1172,11 @@ def forward( split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - - # Process chunk: embeddings -> layers -> output batch_size, sequence_length = split.shape input_embeds = self.embed_tokens(split) - hidden_states = input_embeds - batch_size, _, _ = input_embeds.shape - position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) cache_position = torch.arange(sequence_length, device=input_embeds.device) @@ -1466,14 +1197,9 @@ def forward( predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) - concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - - # Always compute patch lengths from concatenated entropies batch_size, sequence_length = token_values.shape - - # Find patch start IDs based on entropy if patch_size is not None: patch_lengths = self.patch_lengths_from_entropies( entropies=concat_entropies, @@ -1482,7 +1208,6 @@ def forward( threshold=threshold, ) else: - # Default to byte-level patching patch_lengths = torch.ones( (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device ) From 1df0b6a22e2ffff0e329d0f2ef3773bfc37cfb60 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 22 Jul 2025 16:01:07 +0000 Subject: [PATCH 094/139] fix hidden_states shape test --- src/transformers/models/blt/modeling_blt.py | 60 +++++++-------------- 1 file changed, 18 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 37d1168ff7fc..30c8cb074e2d 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1,9 +1,3 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/blt/modular_blt.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_blt.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 HuggingFace Inc. team. All rights reserved. # @@ -35,7 +29,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils.generic import check_model_inputs from .configuration_blt import ( @@ -303,6 +297,15 @@ def forward( return attn_output, attn_weights +# Decoder-specific layer class for automatic hidden states collection +class BLTDecoderLayer(BLTTransformerLayer): + """ + BLT decoder layer - identical to BLTTransformerLayer but with a different class name + for selective automatic collection of hidden states from decoder layers only. + """ + pass + + # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) @@ -417,7 +420,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.config = config self.cross_attn_decoder = True self.layers = nn.ModuleList( - [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BLTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( @@ -814,7 +817,7 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_cache_class = False _can_record_outputs = { - "hidden_states": BLTTransformerLayer, + "hidden_states": BLTDecoderLayer, "attentions": BLTSelfAttention, } @@ -843,7 +846,6 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -861,7 +863,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -895,7 +897,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -994,17 +995,9 @@ def forward( past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is None and inputs_embeds is None: @@ -1052,7 +1045,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, encoder_embeds, cache_position, past_key_values, output_attentions + attention_mask, encoder_embeds, cache_position, past_key_values ) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype @@ -1079,7 +1072,7 @@ def forward( ) global_position_ids = global_cache_position.unsqueeze(0) global_causal_mask = self._update_causal_mask( - None, global_hidden_states, global_cache_position, None, False + None, global_hidden_states, global_cache_position, None ) global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, @@ -1185,7 +1178,6 @@ def forward( input_embeds, cache_position, None, # past_key_values - False # output_attentions ) for i, layer in enumerate(self.layers): @@ -1316,6 +1308,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @can_return_tuple @auto_docstring def forward( self, @@ -1329,13 +1322,10 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, CausalLMOutputWithPast]: + ) -> CausalLMOutputWithPast: r""" cross_attention_states (`torch.FloatTensor`, *optional*): Output of the vision model, used for cross-attention. This tensor contains the processed image features that @@ -1380,15 +1370,8 @@ def forward( I love the idea of snowflakes gently falling, each one ``` """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, @@ -1397,9 +1380,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1412,10 +1392,6 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, From 3f7d5cdde21af523eb21eff05349febbfef6400b Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 23 Jul 2025 17:51:02 +0000 Subject: [PATCH 095/139] refactor to new outputs --- .../models/blt/configuration_blt.py | 6 ---- src/transformers/models/blt/modeling_blt.py | 25 ++++++++++++--- tests/models/blt/test_modeling_blt.py | 32 ++++++++++++++++++- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 102aa8bb1792..3e3075e5ab74 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -383,12 +383,6 @@ def __init__( super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - # Add decoder config attributes to main config for compatibility with tests - # These mirror the decoder config attributes since the main model interface uses the decoder - # self.hidden_size = self.decoder_config.hidden_size - # self.num_hidden_layers = self.decoder_config.num_hidden_layers - # self.num_attention_heads = self.decoder_config.num_attention_heads - __all__ = [ "BLTConfig", diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 30c8cb074e2d..fc91fa451cd6 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -30,7 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, OutputRecorder from .configuration_blt import ( BLTConfig, @@ -145,7 +145,7 @@ def forward( ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, @@ -312,6 +312,7 @@ class BLTDecoderLayer(BLTTransformerLayer): class BLTLocalEncoder(nn.Module): def __init__(self, config: BLTLocalEncoderConfig): super().__init__() + self.gradient_checkpointing = False self.config = config self.layers = nn.ModuleList( [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] @@ -417,6 +418,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BLTLocalDecoder(nn.Module): def __init__(self, config: BLTLocalDecoderConfig): super().__init__() + self.gradient_checkpointing = False self.config = config self.cross_attn_decoder = True self.layers = nn.ModuleList( @@ -436,6 +438,7 @@ def __init__(self, config: BLTLocalDecoderConfig): BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -818,7 +821,9 @@ class BLTPreTrainedModel(PreTrainedModel): _can_record_outputs = { "hidden_states": BLTDecoderLayer, - "attentions": BLTSelfAttention, + "attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="global_transformer"), } def _init_weights(self, module): @@ -966,6 +971,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) + self.gradient_checkpointing = False self.config = config self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) @@ -1371,7 +1377,8 @@ def forward( ``` """ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, @@ -1392,13 +1399,21 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - return CausalLMOutputWithPast( + output = CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + # Add BLT-specific attention outputs + if hasattr(outputs, 'encoder_attentions'): + output.encoder_attentions = outputs.encoder_attentions + if hasattr(outputs, 'global_attentions'): + output.global_attentions = outputs.global_attentions + + return output __all__ = [ diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 2eda2b08d21c..d07ac78a6d39 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -140,6 +140,8 @@ def __init__( "dropout": self.dropout, } + self.num_hidden_layers = self.encoder_config["num_hidden_layers"] + def get_config(self): config = BLTConfig( vocab_size=self.vocab_size, @@ -163,7 +165,7 @@ def get_config(self): ) config.num_attention_heads = config.decoder_config.num_attention_heads - config.num_hidden_layers = config.decoder_config.num_hidden_layers + config.num_hidden_layers = config.encoder_config.num_hidden_layers config.hidden_size = config.decoder_config.hidden_size return config @@ -312,6 +314,34 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_flex_attention_with_grads(): return + def test_attention_outputs(self): + if not self.has_attentions: + self.skipTest(reason="Model does not output attentions") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + config._attn_implementation = "eager" + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + # For BLT, check separate attention outputs from each component + attentions = outputs.attentions + encoder_attentions = outputs.encoder_attentions + global_attentions = outputs.global_attentions + + # Each component should have attention outputs equal to their layer count + self.assertEqual(len(attentions), config.decoder_config.num_hidden_layers) + self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) + self.assertEqual(len(global_attentions), config.global_config.num_hidden_layers) + @require_torch_accelerator class BLTIntegrationTest(unittest.TestCase): From 22a511a704ec16e374ebc443ab02ce1e50845676 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 24 Jul 2025 10:05:12 +0000 Subject: [PATCH 096/139] simplify outputs a bit --- src/transformers/models/blt/modeling_blt.py | 44 ++++++++++----------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index fc91fa451cd6..035640229974 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -265,7 +265,8 @@ def forward( # Ensure hidden_states is always 3D (batch, seq_len, hidden_dim) if hidden_states.dim() == 2: hidden_states = hidden_states.unsqueeze(0) - bsz, q_len, _ = hidden_states.size() + bsz = hidden_states.size(0) + q_len = hidden_states.size(1) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -297,15 +298,6 @@ def forward( return attn_output, attn_weights -# Decoder-specific layer class for automatic hidden states collection -class BLTDecoderLayer(BLTTransformerLayer): - """ - BLT decoder layer - identical to BLTTransformerLayer but with a different class name - for selective automatic collection of hidden states from decoder layers only. - """ - pass - - # BLT-SPECIFIC COMPONENTS (no Mllama equivalent) @@ -348,7 +340,7 @@ def forward( ): if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - batch_size, _, _ = input_embeds.shape + batch_size = input_embeds.shape[0] hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) if position_ids is None: position_ids = ( @@ -396,7 +388,8 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, _, embedding_dim = hidden_states.shape + batch_size = hidden_states.shape[0] + embedding_dim = hidden_states.shape[-1] patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -422,7 +415,7 @@ def __init__(self, config: BLTLocalDecoderConfig): self.config = config self.cross_attn_decoder = True self.layers = nn.ModuleList( - [BLTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.rotary_emb = BLTRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( @@ -453,7 +446,7 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size, _, _ = embeds.shape + batch_size = embeds.shape[0] hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( @@ -525,7 +518,8 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - bsz, q_len, _ = hidden_states.size() + bsz = hidden_states.size(0) + q_len = hidden_states.size(1) query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -586,7 +580,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size, seq_len, _ = input_embeds.shape + batch_size = input_embeds.shape[0] + seq_len = input_embeds.shape[1] hidden_states = input_embeds hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) if position_ids is None: @@ -631,7 +626,8 @@ def byte_group_hash_function( ): """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): - batch_size, seq_len = token_ids.shape + batch_size = token_ids.shape[0] + seq_len = token_ids.shape[1] # Add padding for sliding window padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) padded_tokens = torch.cat([padding, token_ids], dim=1) @@ -693,7 +689,8 @@ def _prepare_patch_cross_attention_mask( - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ - batch_size, seq_len = patch_ids.shape + batch_size = patch_ids.shape[0] + seq_len = patch_ids.shape[1] device = patch_ids.device # Determine query and key lengths based on configuration @@ -820,7 +817,7 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_cache_class = False _can_record_outputs = { - "hidden_states": BLTDecoderLayer, + "hidden_states": OutputRecorder(BLTTransformerLayer, index=1, layer_name="local_decoder"), "attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_decoder"), "encoder_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_encoder"), "global_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="global_transformer"), @@ -1171,10 +1168,11 @@ def forward( split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - batch_size, sequence_length = split.shape + batch_size = split.shape[0] + sequence_length = split.shape[1] input_embeds = self.embed_tokens(split) hidden_states = input_embeds - batch_size, _, _ = input_embeds.shape + batch_size = input_embeds.shape[0] position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -1197,7 +1195,7 @@ def forward( entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - batch_size, sequence_length = token_values.shape + batch_size = token_values.shape[0] if patch_size is not None: patch_lengths = self.patch_lengths_from_entropies( entropies=concat_entropies, @@ -1407,7 +1405,7 @@ def forward( attentions=outputs.attentions, ) - # Add BLT-specific attention outputs + # Add BLT-specific attention outputs - TODO: this is needed only for the test_attention_outputs test if hasattr(outputs, 'encoder_attentions'): output.encoder_attentions = outputs.encoder_attentions if hasattr(outputs, 'global_attentions'): From 0239f77346e29f7fc713972ffd36e495cb659b78 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 28 Jul 2025 13:33:37 +0000 Subject: [PATCH 097/139] rm unneeded decoderlayer overwriting --- src/transformers/models/blt/modeling_blt.py | 33 ++++++++------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 035640229974..2c3ecb78fec8 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -265,8 +265,7 @@ def forward( # Ensure hidden_states is always 3D (batch, seq_len, hidden_dim) if hidden_states.dim() == 2: hidden_states = hidden_states.unsqueeze(0) - bsz = hidden_states.size(0) - q_len = hidden_states.size(1) + bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -340,7 +339,7 @@ def forward( ): if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - batch_size = input_embeds.shape[0] + batch_size, _, _ = input_embeds.shape hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) if position_ids is None: position_ids = ( @@ -388,8 +387,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size = hidden_states.shape[0] - embedding_dim = hidden_states.shape[-1] + batch_size, _, embedding_dim = hidden_states.shape patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -446,7 +444,7 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size = embeds.shape[0] + batch_size, _, _ = embeds.shape hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( @@ -518,8 +516,7 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - bsz = hidden_states.size(0) - q_len = hidden_states.size(1) + bsz, q_len, _ = hidden_states.size() query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -580,8 +577,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size = input_embeds.shape[0] - seq_len = input_embeds.shape[1] + batch_size, seq_len, _ = input_embeds.shape hidden_states = input_embeds hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) if position_ids is None: @@ -626,8 +622,7 @@ def byte_group_hash_function( ): """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): - batch_size = token_ids.shape[0] - seq_len = token_ids.shape[1] + batch_size, seq_len = token_ids.shape # Add padding for sliding window padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) padded_tokens = torch.cat([padding, token_ids], dim=1) @@ -689,8 +684,7 @@ def _prepare_patch_cross_attention_mask( - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ - batch_size = patch_ids.shape[0] - seq_len = patch_ids.shape[1] + batch_size, seq_len = patch_ids.shape device = patch_ids.device # Determine query and key lengths based on configuration @@ -817,7 +811,7 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_cache_class = False _can_record_outputs = { - "hidden_states": OutputRecorder(BLTTransformerLayer, index=1, layer_name="local_decoder"), + "hidden_states": OutputRecorder(BLTTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_decoder"), "encoder_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_encoder"), "global_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="global_transformer"), @@ -1168,11 +1162,10 @@ def forward( split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - batch_size = split.shape[0] - sequence_length = split.shape[1] + batch_size, sequence_length = split.shape input_embeds = self.embed_tokens(split) hidden_states = input_embeds - batch_size = input_embeds.shape[0] + batch_size, _, _ = input_embeds.shape position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -1195,7 +1188,7 @@ def forward( entropies.append(prediction_entropies) concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - batch_size = token_values.shape[0] + batch_size, sequence_length = token_values.shape if patch_size is not None: patch_lengths = self.patch_lengths_from_entropies( entropies=concat_entropies, @@ -1405,7 +1398,7 @@ def forward( attentions=outputs.attentions, ) - # Add BLT-specific attention outputs - TODO: this is needed only for the test_attention_outputs test + # Add BLT-specific attention outputs if hasattr(outputs, 'encoder_attentions'): output.encoder_attentions = outputs.encoder_attentions if hasattr(outputs, 'global_attentions'): From 926fb0928266793107aec5407f1a15377ab6c2c4 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 28 Jul 2025 14:15:39 +0000 Subject: [PATCH 098/139] rename blt --- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 6 +- .../models/auto/tokenization_auto.py | 2 +- .../models/blt/configuration_blt.py | 102 ++++++------ .../models/blt/convert_blt_weights_to_hf.py | 6 +- src/transformers/models/blt/modeling_blt.py | 155 +++++++++--------- .../models/blt/tokenization_blt.py | 26 +-- tests/models/blt/test_modeling_blt.py | 76 ++++----- 8 files changed, 189 insertions(+), 188 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 47d4673a3b8e..ec6ce58f7994 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -65,7 +65,7 @@ ("blip-2", "Blip2Config"), ("blip_2_qformer", "Blip2QFormerConfig"), ("bloom", "BloomConfig"), - ("blt", "BLTConfig"), + ("blt", "BltConfig"), ("bridgetower", "BridgeTowerConfig"), ("bros", "BrosConfig"), ("camembert", "CamembertConfig"), @@ -491,7 +491,7 @@ ("blip-2", "BLIP-2"), ("blip_2_qformer", "BLIP-2 QFormer"), ("bloom", "BLOOM"), - ("blt", "BLT"), + ("blt", "Blt"), ("bort", "BORT"), ("bridgetower", "BridgeTower"), ("bros", "BROS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d758acd687bd..372e6b723c8a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -72,8 +72,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blip-2", "Blip2Model"), ("blip_2_qformer", "Blip2QFormerModel"), ("bloom", "BloomModel"), - ("blt", "BLTModel"), - ("blt", "BLTModel"), + ("blt", "BltModel"), + ("blt", "BltModel"), ("bridgetower", "BridgeTowerModel"), ("bros", "BrosModel"), ("camembert", "CamembertModel"), @@ -635,7 +635,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blenderbot", "BlenderbotForCausalLM"), ("blenderbot-small", "BlenderbotSmallForCausalLM"), ("bloom", "BloomForCausalLM"), - ("blt", "BLTForCausalLM"), + ("blt", "BltForCausalLM"), ("camembert", "CamembertForCausalLM"), ("code_llama", "LlamaForCausalLM"), ("codegen", "CodeGenForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index b79a5d5c8df3..e128b927ab8e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -105,7 +105,7 @@ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)), - ("blt", ("BLTTokenizer", None)), + ("blt", ("BltTokenizer", None)), ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("byt5", ("ByT5Tokenizer", None)), diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 3e3075e5ab74..1a3557b7f2bd 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BLT model configuration""" +"""Blt model configuration""" from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -21,9 +21,9 @@ logger = logging.get_logger(__name__) -class BLTLocalEncoderConfig(PretrainedConfig): +class BltLocalEncoderConfig(PretrainedConfig): """ - Configuration class for the BLT Local Encoder component. + Configuration class for the Blt Local Encoder component. """ model_type = "blt_local_encoder" @@ -69,9 +69,9 @@ def __init__( super().__init__(**kwargs) -class BLTLocalDecoderConfig(PretrainedConfig): +class BltLocalDecoderConfig(PretrainedConfig): """ - Configuration class for the BLT Local Decoder component. + Configuration class for the Blt Local Decoder component. """ model_type = "blt_local_decoder" @@ -115,9 +115,9 @@ def __init__( super().__init__(**kwargs) -class BLTGlobalTransformerConfig(PretrainedConfig): +class BltGlobalTransformerConfig(PretrainedConfig): """ - Configuration class for the BLT Global Transformer component. + Configuration class for the Blt Global Transformer component. """ model_type = "blt_global_transformer" @@ -153,9 +153,9 @@ def __init__( super().__init__(**kwargs) -class BLTPatcherConfig(PretrainedConfig): +class BltPatcherConfig(PretrainedConfig): r""" - Configuration class for the BLT Patcher/Entropy model component. + Configuration class for the Blt Patcher/Entropy model component. Args: vocab_size (`int`, *optional*, defaults to 256): @@ -215,24 +215,24 @@ def __init__( self.dropout = dropout self.rope_theta = rope_theta self.attn_bias_type = attn_bias_type - self.hidden_act = "silu" # BLT uses silu activation + self.hidden_act = "silu" # Blt uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) self.rope_scaling = rope_scaling super().__init__(**kwargs) -class BLTConfig(PretrainedConfig): +class BltConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a - BLT model according to the specified arguments, defining the model architecture. + This is the configuration class to store the configuration of a [`BltModel`]. It is used to instantiate a + Blt model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 256): - Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. + Vocabulary size of the Blt model. Defines the number of different tokens (bytes) that can be represented. max_position_embeddings (`int`, *optional*, defaults to 1024): The maximum sequence length that this model can handle. # Patching configuration @@ -264,23 +264,23 @@ class BLTConfig(PretrainedConfig): Number of hash functions for byte groups. # Component configurations - patcher_config (`Union[BLTPatcherConfig, dict]`, *optional*): - Configuration for the BLT patcher/entropy model component. - encoder_config (`Union[BLTLocalEncoderConfig, dict]`, *optional*): - Configuration for the BLT local encoder component. - decoder_config (`Union[BLTLocalDecoderConfig, dict]`, *optional*): - Configuration for the BLT local decoder component. - global_config (`Union[BLTGlobalTransformerConfig, dict]`, *optional*): - Configuration for the BLT global transformer component. + patcher_config (`Union[BltPatcherConfig, dict]`, *optional*): + Configuration for the Blt patcher/entropy model component. + encoder_config (`Union[BltLocalEncoderConfig, dict]`, *optional*): + Configuration for the Blt local encoder component. + decoder_config (`Union[BltLocalDecoderConfig, dict]`, *optional*): + Configuration for the Blt local decoder component. + global_config (`Union[BltGlobalTransformerConfig, dict]`, *optional*): + Configuration for the Blt global transformer component. ```python - >>> from transformers import BLTModel, BLTConfig + >>> from transformers import BltModel, BltConfig - >>> # Initializing a BLT configuration - >>> configuration = BLTConfig() + >>> # Initializing a Blt configuration + >>> configuration = BltConfig() >>> # Initializing a model from the configuration - >>> model = BLTModel(configuration) + >>> model = BltModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -289,10 +289,10 @@ class BLTConfig(PretrainedConfig): model_type = "blt" keys_to_ignore_at_inference = ["past_key_values"] sub_configs = { - "patcher_config": BLTPatcherConfig, - "encoder_config": BLTLocalEncoderConfig, - "decoder_config": BLTLocalDecoderConfig, - "global_config": BLTGlobalTransformerConfig, + "patcher_config": BltPatcherConfig, + "encoder_config": BltLocalEncoderConfig, + "decoder_config": BltLocalDecoderConfig, + "global_config": BltGlobalTransformerConfig, } def __init__( @@ -350,44 +350,44 @@ def __init__( # Initialize component configurations if patcher_config is None: - self.patcher_config = BLTPatcherConfig() - logger.info("patcher_config is None, using default BLT patcher config") + self.patcher_config = BltPatcherConfig() + logger.info("patcher_config is None, using default Blt patcher config") elif isinstance(patcher_config, dict): - self.patcher_config = BLTPatcherConfig(**patcher_config) - elif isinstance(patcher_config, BLTPatcherConfig): + self.patcher_config = BltPatcherConfig(**patcher_config) + elif isinstance(patcher_config, BltPatcherConfig): self.patcher_config = patcher_config if encoder_config is None: - self.encoder_config = BLTLocalEncoderConfig() - logger.info("encoder_config is None, using default BLT encoder config") + self.encoder_config = BltLocalEncoderConfig() + logger.info("encoder_config is None, using default Blt encoder config") elif isinstance(encoder_config, dict): - self.encoder_config = BLTLocalEncoderConfig(**encoder_config) - elif isinstance(encoder_config, BLTLocalEncoderConfig): + self.encoder_config = BltLocalEncoderConfig(**encoder_config) + elif isinstance(encoder_config, BltLocalEncoderConfig): self.encoder_config = encoder_config if decoder_config is None: - self.decoder_config = BLTLocalDecoderConfig() - logger.info("decoder_config is None, using default BLT decoder config") + self.decoder_config = BltLocalDecoderConfig() + logger.info("decoder_config is None, using default Blt decoder config") elif isinstance(decoder_config, dict): - self.decoder_config = BLTLocalDecoderConfig(**decoder_config) - elif isinstance(decoder_config, BLTLocalDecoderConfig): + self.decoder_config = BltLocalDecoderConfig(**decoder_config) + elif isinstance(decoder_config, BltLocalDecoderConfig): self.decoder_config = decoder_config if global_config is None: - self.global_config = BLTGlobalTransformerConfig() - logger.info("global_config is None, using default BLT global config") + self.global_config = BltGlobalTransformerConfig() + logger.info("global_config is None, using default Blt global config") elif isinstance(global_config, dict): - self.global_config = BLTGlobalTransformerConfig(**global_config) - elif isinstance(global_config, BLTGlobalTransformerConfig): + self.global_config = BltGlobalTransformerConfig(**global_config) + elif isinstance(global_config, BltGlobalTransformerConfig): self.global_config = global_config super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) __all__ = [ - "BLTConfig", - "BLTPatcherConfig", - "BLTLocalEncoderConfig", - "BLTLocalDecoderConfig", - "BLTGlobalTransformerConfig", + "BltConfig", + "BltPatcherConfig", + "BltLocalEncoderConfig", + "BltLocalDecoderConfig", + "BltGlobalTransformerConfig", ] diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index 713f0e8ef112..a436ca3560cf 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -261,7 +261,7 @@ def create_tokenizer_config(output_dir: str, config: dict[str, Any]): def push_to_hub( local_dir: str, repo_id: str, - commit_message: str = "Upload converted BLT model", + commit_message: str = "Upload converted Blt model", private: bool = False, token: Optional[str] = None, ) -> None: @@ -321,7 +321,7 @@ def convert_hf_blt_to_unified( push_to_hub( local_dir=output_dir, repo_id=push_to_hub_repo, - commit_message="Upload BLT model converted", + commit_message="Upload Blt model converted", private=hub_private, token=hub_token, ) @@ -329,7 +329,7 @@ def convert_hf_blt_to_unified( def main(): parser = argparse.ArgumentParser( - description="Convert BLT models from HuggingFace Hub format to unified format", + description="Convert Blt models from HuggingFace Hub format to unified format", formatter_class=argparse.RawDescriptionHelpFormatter, ) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 2c3ecb78fec8..48af5f56d14b 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -33,11 +33,11 @@ from ...utils.generic import check_model_inputs, OutputRecorder from .configuration_blt import ( - BLTConfig, - BLTGlobalTransformerConfig, - BLTLocalDecoderConfig, - BLTLocalEncoderConfig, - BLTPatcherConfig, + BltConfig, + BltGlobalTransformerConfig, + BltLocalDecoderConfig, + BltLocalEncoderConfig, + BltPatcherConfig, ) @@ -52,7 +52,7 @@ logger = logging.get_logger(__name__) -class BLTMLP(nn.Module): +class BltMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config @@ -69,10 +69,10 @@ def forward(self, x, **kwargs: Unpack[TransformersKwargs]): return down_proj -class BLTRMSNorm(nn.Module): +class BltRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - BLTRMSNorm is equivalent to T5LayerNorm + BltRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -89,8 +89,8 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class BLTRotaryEmbedding(nn.Module): - def __init__(self, config: BLTConfig, device=None): +class BltRotaryEmbedding(nn.Module): + def __init__(self, config: BltConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" self.rope_type = ( @@ -124,14 +124,14 @@ def forward(self, x, position_ids, **kwargs: Unpack[TransformersKwargs]): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class BLTTransformerLayer(nn.Module): +class BltTransformerLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BltMLP(config) + self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx def forward( @@ -202,7 +202,7 @@ def eager_attention_forward( def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # TODO: not exactly equivalent to other transformers implementations,, need feedback # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs + # This matches the original Blt approach which uses head_dim//2 frequency pairs head_dim = q.shape[-1] cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] @@ -232,10 +232,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_rot.type_as(q), k_rot.type_as(k) -class BLTSelfAttention(nn.Module): - """BLT variant of MllamaTextSelfAttention. Inherits all logic directly.""" +class BltSelfAttention(nn.Module): + """Blt variant of MllamaTextSelfAttention. Inherits all logic directly.""" - def __init__(self, config: BLTConfig, layer_idx: int): + def __init__(self, config: BltConfig, layer_idx: int): super().__init__() self.config = config self.num_heads = config.num_attention_heads @@ -297,18 +297,18 @@ def forward( return attn_output, attn_weights -# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) +# Blt-SPECIFIC COMPONENTS (no Mllama equivalent) -class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTLocalEncoderConfig): +class BltLocalEncoder(nn.Module): + def __init__(self, config: BltLocalEncoderConfig): super().__init__() self.gradient_checkpointing = False self.config = config self.layers = nn.ModuleList( - [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.rotary_emb = BLTRotaryEmbedding(config=config) + self.rotary_emb = BltRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size, out_features=config.hidden_size * config.cross_attn_k, @@ -319,7 +319,7 @@ def __init__(self, config: BLTLocalEncoderConfig): layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) def forward( @@ -339,7 +339,7 @@ def forward( ): if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - batch_size, _, _ = input_embeds.shape + batch_size = input_embeds.shape[0] hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) if position_ids is None: position_ids = ( @@ -387,7 +387,8 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, _, embedding_dim = hidden_states.shape + batch_size = hidden_states.shape[0] + embedding_dim = hidden_states.shape[-1] patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -406,27 +407,27 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): return reduced_embeddings -class BLTLocalDecoder(nn.Module): - def __init__(self, config: BLTLocalDecoderConfig): +class BltLocalDecoder(nn.Module): + def __init__(self, config: BltLocalDecoderConfig): super().__init__() self.gradient_checkpointing = False self.config = config self.cross_attn_decoder = True self.layers = nn.ModuleList( - [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.rotary_emb = BLTRotaryEmbedding(config=config) + self.rotary_emb = BltRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) @check_model_inputs @@ -444,7 +445,7 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size, _, _ = embeds.shape + batch_size = embeds.shape[0] hidden_states = embeds patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( @@ -482,10 +483,10 @@ def forward( return logits -class BLTCrossAttention(nn.Module): - """Cross-attention module for BLT, following transformers style""" +class BltCrossAttention(nn.Module): + """Cross-attention module for Blt, following transformers style""" - def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): super().__init__() self.config = config self.num_heads = self.config.num_attention_heads @@ -502,8 +503,8 @@ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) # needs to stay hidden_size, NOT head_dim - self.q_norm = BLTRMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.k_norm = BLTRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.is_causal = False def forward( @@ -559,14 +560,14 @@ def forward( return attn_output, attn_weights -class BLTGlobalTransformer(nn.Module): - def __init__(self, config: BLTGlobalTransformerConfig): +class BltGlobalTransformer(nn.Module): + def __init__(self, config: BltGlobalTransformerConfig): super().__init__() self.config = config self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(config, layer_idx)) - self.rotary_emb = BLTRotaryEmbedding(config=config) + self.layers.append(BltTransformerLayer(config, layer_idx)) + self.rotary_emb = BltRotaryEmbedding(config=config) def forward( self, @@ -791,13 +792,13 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona @auto_docstring -class BLTPreTrainedModel(PreTrainedModel): - """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" +class BltPreTrainedModel(PreTrainedModel): + """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" - config: BLTConfig + config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True @@ -805,16 +806,16 @@ class BLTPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - config_class = BLTConfig + config_class = BltConfig _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = False _supports_cache_class = False _can_record_outputs = { - "hidden_states": OutputRecorder(BLTTransformerLayer, index=0, layer_name="local_decoder"), - "attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BLTSelfAttention, index=1, layer_name="global_transformer"), + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } def _init_weights(self, module): @@ -831,7 +832,7 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() - elif isinstance(module, BLTRMSNorm): + elif isinstance(module, BltRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, nn.RMSNorm): module.weight.data.fill_(1.0) @@ -959,14 +960,14 @@ def _prepare_4d_causal_attention_mask_with_cache_position( # Top-level model classes -class BLTModel(BLTPreTrainedModel): - def __init__(self, config: BLTConfig): +class BltModel(BltPreTrainedModel): + def __init__(self, config: BltConfig): super().__init__(config) self.gradient_checkpointing = False self.config = config - self.local_encoder = BLTLocalEncoder(config.encoder_config) - self.global_transformer = BLTGlobalTransformer(config.global_config) - self.local_decoder = BLTLocalDecoder(config.decoder_config) + self.local_encoder = BltLocalEncoder(config.encoder_config) + self.global_transformer = BltGlobalTransformer(config.global_config) + self.local_decoder = BltLocalDecoder(config.decoder_config) num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) embeddings = [ nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) @@ -974,7 +975,7 @@ def __init__(self, config: BLTConfig): ] self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) if self.config.patch_in_forward: - self.patcher = BLTPatcher(config.patcher_config) + self.patcher = BltPatcher(config.patcher_config) self.patcher.eval() for param in self.patcher.parameters(): param.requires_grad = False @@ -1125,15 +1126,15 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 -class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config: BLTPatcherConfig): +class BltPatcher(BltPreTrainedModel): + def __init__(self, config: BltPatcherConfig): super().__init__(config) - self.rotary_emb = BLTRotaryEmbedding(config=self.config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() for layer_idx in range(self.config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + self.layers.append(BltTransformerLayer(self.config, layer_idx)) self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, @@ -1165,7 +1166,7 @@ def forward( batch_size, sequence_length = split.shape input_embeds = self.embed_tokens(split) hidden_states = input_embeds - batch_size, _, _ = input_embeds.shape + batch_size = input_embeds.shape[0] position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -1267,22 +1268,22 @@ def patch_lengths_from_entropies( @auto_docstring( custom_intro=""" - The BLT Text Model with a language modeling head on top. + The Blt Text Model with a language modeling head on top. """ ) -class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): - config: BLTConfig +class BltForCausalLM(BltPreTrainedModel, GenerationMixin): + config: BltConfig _supports_static_cache = True # only the LLM without cross attn can do compile base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config): super().__init__(config.get_text_config()) self.text_config = config.get_text_config() self.vocab_size = config.vocab_size - self.model = BLTModel(config) + self.model = BltModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) self.post_init() @@ -1351,9 +1352,9 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, BLTForCausalLM + >>> from transformers import AutoTokenizer, BltForCausalLM - >>> model = BLTForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> model = BltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") >>> prompt = "If I had to write a haiku, it would be:" @@ -1398,7 +1399,7 @@ def forward( attentions=outputs.attentions, ) - # Add BLT-specific attention outputs + # Add Blt-specific attention outputs if hasattr(outputs, 'encoder_attentions'): output.encoder_attentions = outputs.encoder_attentions if hasattr(outputs, 'global_attentions'): @@ -1408,8 +1409,8 @@ def forward( __all__ = [ - "BLTPreTrainedModel", - "BLTModel", - "BLTPatcher", - "BLTForCausalLM", + "BltPreTrainedModel", + "BltModel", + "BltPatcher", + "BltForCausalLM", ] \ No newline at end of file diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 1c6c39544a59..e4c0afdbc821 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for BLT.""" +"""Tokenization classes for Blt.""" from typing import TYPE_CHECKING, Optional @@ -25,7 +25,7 @@ logger = logging.get_logger(__name__) -# BLT tokenizer constants +# Blt tokenizer constants SEP = " " BOS_ID: int = 1 EOS_ID: int = 2 @@ -35,12 +35,12 @@ OFFSET: int = 4 BYTE_UNITS: int = 256 -VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files +VOCAB_FILES_NAMES = {} # Blt doesn't require external vocab files -class BLTTokenizer(PreTrainedTokenizer): +class BltTokenizer(PreTrainedTokenizer): """ - Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token. + Construct a Blt tokenizer. Based on byte-level tokenization where each byte is treated as a token. This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), @@ -54,9 +54,9 @@ class BLTTokenizer(PreTrainedTokenizer): pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): The padding token. unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. Not used in BLT but kept for compatibility. + The unknown token. Not used in Blt but kept for compatibility. boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of example token, specific to BLT. + The beginning of example token, specific to Blt. add_bos_token (`bool`, *optional*, defaults to `True`): Whether or not to add a `bos_token` at the start of sequences. add_eos_token (`bool`, *optional*, defaults to `True`): @@ -83,13 +83,13 @@ def __init__( spaces_between_special_tokens=False, **kwargs, ): - # Store BLT-specific parameters first + # Store Blt-specific parameters first self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.vocab_size_unit_1 = BYTE_UNITS self.offsetting_special_char = OFFSET - # BLT token IDs (exactly like original) + # Blt token IDs (exactly like original) self.boe_id = BOE_ID self.bos_id = BOS_ID self.eos_id = EOS_ID @@ -174,7 +174,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: return bytes(byte_values).decode("utf-8", errors="ignore") def _tokenize(self, text: str, **kwargs) -> list[str]: - """Converts a string to a list of tokens. For BLT, we work directly with byte values.""" + """Converts a string to a list of tokens. For Blt, we work directly with byte values.""" return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] def build_inputs_with_special_tokens( @@ -182,7 +182,7 @@ def build_inputs_with_special_tokens( ) -> list[int]: """ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and - adding special tokens. A BLT sequence has the following format: + adding special tokens. A Blt sequence has the following format: - single sequence: ` X ` - pair of sequences: ` A B ` @@ -238,8 +238,8 @@ def get_vocab_size(self) -> int: return self.vocab_size def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: - # BLT doesn't require external vocabulary files since it uses byte-level tokenization + # Blt doesn't require external vocabulary files since it uses byte-level tokenization return () -__all__ = ["BLTTokenizer"] +__all__ = ["BltTokenizer"] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index d07ac78a6d39..0ece0ec2e288 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch BLT model.""" +"""Testing suite for the PyTorch Blt model.""" import unittest @@ -41,15 +41,15 @@ if is_torch_available(): import torch - from transformers import BLTConfig, BLTForCausalLM, BLTModel - from transformers.models.blt.modeling_blt import BLTRotaryEmbedding + from transformers import BltConfig, BltForCausalLM, BltModel +from transformers.models.blt.modeling_blt import BltRotaryEmbedding -class BLTModelTester(CausalLMModelTester): +class BltModelTester(CausalLMModelTester): if is_torch_available(): - config_class = BLTConfig - base_model_class = BLTModel - causal_lm_class = BLTForCausalLM + config_class = BltConfig + base_model_class = BltModel + causal_lm_class = BltForCausalLM def __init__( self, @@ -143,7 +143,7 @@ def __init__( self.num_hidden_layers = self.encoder_config["num_hidden_layers"] def get_config(self): - config = BLTConfig( + config = BltConfig( vocab_size=self.vocab_size, max_position_embeddings=self.max_position_embeddings, patch_in_forward=False, # Disable patching for tests @@ -172,19 +172,19 @@ def get_config(self): @require_torch -class BLTModelTest(CausalLMModelTest, unittest.TestCase): +class BltModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( - BLTModel, - BLTForCausalLM, + BltModel, + BltForCausalLM, ) if is_torch_available() else () ) pipeline_model_mapping = ( { - "feature-extraction": BLTModel, - "text-generation": BLTForCausalLM, + "feature-extraction": BltModel, + "text-generation": BltForCausalLM, } if is_torch_available() else {} @@ -192,29 +192,29 @@ class BLTModelTest(CausalLMModelTest, unittest.TestCase): test_headmasking = False test_pruning = False fx_compatible = False - model_tester_class = BLTModelTester - rotary_embedding_layer = BLTRotaryEmbedding # Enables RoPE tests if set + model_tester_class = BltModelTester + rotary_embedding_layer = BltRotaryEmbedding # Enables RoPE tests if set # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] # used in `test_torch_compile_for_training` - _torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None + _torch_compile_train_cls = BltForCausalLM if is_torch_available() else None @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) def test_generate_from_inputs_embeds(self, _, num_beams): - """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" + """Skip this test for Blt as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" self.skipTest( - "BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" ) @pytest.mark.generate def test_inputs_embeds_matches_input_ids(self): - """Skip this test for BLT as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" + """Skip this test for Blt as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" self.skipTest( - "BLT requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" ) @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @@ -248,16 +248,16 @@ def test_eager_matches_sdpa_inference( ) def test_torchscript_simple(self): - """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" - self.skipTest("BLT has complex patching logic that's not compatible with torchscript") + """Skip torchscript test for Blt as it has complex patching logic that's not compatible.""" + self.skipTest("Blt has complex patching logic that's not compatible with torchscript") def test_torchscript_output_hidden_state(self): - """Skip torchscript test for BLT as it has complex patching logic that's not compatible.""" - self.skipTest("BLT has complex patching logic that's not compatible with torchscript") + """Skip torchscript test for Blt as it has complex patching logic that's not compatible.""" + self.skipTest("Blt has complex patching logic that's not compatible with torchscript") @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): - """Override rope scaling from config test to handle BLT's sub-config structure.""" + """Override rope scaling from config test to handle Blt's sub-config structure.""" if self.rotary_embedding_layer is None: self.skipTest("Rotary embedding layer not set") config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -273,7 +273,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): set_seed(42) # Fixed seed at init time so the two models get the same random weights config.rope_scaling = {"rope_type": scaling_type, "factor": 10.0} - # Propagate rope_scaling to sub-configs for BLT + # Propagate rope_scaling to sub-configs for Blt config.encoder_config.rope_scaling = config.rope_scaling config.decoder_config.rope_scaling = config.rope_scaling config.global_config.rope_scaling = config.rope_scaling @@ -332,7 +332,7 @@ def test_attention_outputs(self): model.eval() with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - # For BLT, check separate attention outputs from each component + # For Blt, check separate attention outputs from each component attentions = outputs.attentions encoder_attentions = outputs.encoder_attentions global_attentions = outputs.global_attentions @@ -344,7 +344,7 @@ def test_attention_outputs(self): @require_torch_accelerator -class BLTIntegrationTest(unittest.TestCase): +class BltIntegrationTest(unittest.TestCase): def tearDown(self): # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves # some memory allocated in the cache, which means some object is not being released properly. This causes some @@ -360,7 +360,7 @@ def test_model(self): prompt = "my name is" - model = BLTForCausalLM.from_pretrained( + model = BltForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa" @@ -449,7 +449,7 @@ def test_model_logits(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained("itazap/blt-1b-testing", attn_implementation="sdpa", device_map="auto") + model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", attn_implementation="sdpa", device_map="auto") with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -460,13 +460,13 @@ def test_model_logits(self): @require_read_token @require_torch_bf16 def test_model_bf16(self): - """Test BLT model with bfloat16 precision.""" + """Test Blt model with bfloat16 precision.""" NUM_TOKENS_TO_GENERATE = 200 EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" prompt = "my name is" - model = BLTForCausalLM.from_pretrained( + model = BltForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", @@ -485,7 +485,7 @@ def test_model_bf16(self): @require_read_token @require_torch_bf16 def test_model_logits_bf16(self): - """Test BLT model logits with bfloat16 precision.""" + """Test Blt model logits with bfloat16 precision.""" EXPECTED_OUTPUT = torch.tensor( [ [ @@ -558,7 +558,7 @@ def test_model_logits_bf16(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BLTForCausalLM.from_pretrained( + model = BltForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", @@ -574,13 +574,13 @@ def test_model_logits_bf16(self): @slow @require_read_token def test_model_eager(self): - """Test BLT model with bfloat16 precision using eager attention implementation.""" + """Test Blt model with bfloat16 precision using eager attention implementation.""" NUM_TOKENS_TO_GENERATE = 200 EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" prompt = "my name is" - model = BLTForCausalLM.from_pretrained( + model = BltForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", attn_implementation="eager") @@ -598,13 +598,13 @@ def test_model_eager(self): @require_read_token @require_torch_bf16 def test_model_bf16_static_cache(self): - """Test BLT model with bfloat16 precision and static cache.""" + """Test Blt model with bfloat16 precision and static cache.""" NUM_TOKENS_TO_GENERATE = 200 EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" prompt = "my name is" - model = BLTForCausalLM.from_pretrained( + model = BltForCausalLM.from_pretrained( "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", From 232d245fcc558f3b2331192b612823e6bdc8d8ef Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 28 Jul 2025 14:24:57 +0000 Subject: [PATCH 099/139] forgot tokenizer test renamed --- tests/models/blt/test_tokenization_blt.py | 50 +++++++++++------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py index 503bf504c3ad..d2077104ac2c 100644 --- a/tests/models/blt/test_tokenization_blt.py +++ b/tests/models/blt/test_tokenization_blt.py @@ -14,7 +14,7 @@ import unittest -from transformers import BLTTokenizer +from transformers import BltTokenizer from transformers.testing_utils import require_tokenizers from transformers.tokenization_utils import AddedToken @@ -22,9 +22,9 @@ @require_tokenizers -class BLTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): +class BltTokenizationTest(TokenizerTesterMixin, unittest.TestCase): from_pretrained_id = [] - tokenizer_class = BLTTokenizer + tokenizer_class = BltTokenizer rust_tokenizer_class = None test_rust_tokenizer = False @@ -35,8 +35,8 @@ class BLTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() - # Create a simple BLT tokenizer for testing - tokenizer = BLTTokenizer() + # Create a simple Blt tokenizer for testing + tokenizer = BltTokenizer() tokenizer.save_pretrained(cls.tmpdirname) def get_tokenizers(self, **kwargs): @@ -44,8 +44,8 @@ def get_tokenizers(self, **kwargs): return super().get_tokenizers(**kwargs) def test_blt_tokenizer_basics(self): - """Test basic BLT tokenizer functionality""" - tokenizer = BLTTokenizer() + """Test basic Blt tokenizer functionality""" + tokenizer = BltTokenizer() # Test vocab size (256 bytes + 4 offset + special tokens) self.assertEqual(tokenizer.vocab_size, 261) @@ -63,7 +63,7 @@ def test_blt_tokenizer_basics(self): self.assertEqual(str(tokenizer.pad_token), "") def test_simple_encode_decode(self): - tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) text = "Hello" encoded = tokenizer.encode(text, add_special_tokens=False) @@ -77,7 +77,7 @@ def test_simple_encode_decode(self): self.assertEqual(decoded, text) def test_special_tokens_encoding(self): - tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) text = "Hi" encoded = tokenizer.encode(text, add_special_tokens=True) @@ -88,7 +88,7 @@ def test_special_tokens_encoding(self): self.assertEqual(encoded, expected) def test_tokenize_method(self): - tokenizer = BLTTokenizer() + tokenizer = BltTokenizer() text = "ABC" tokens = tokenizer._tokenize(text) @@ -99,7 +99,7 @@ def test_tokenize_method(self): def test_token_conversion(self): """Test token to ID and ID to token conversion""" - tokenizer = BLTTokenizer() + tokenizer = BltTokenizer() # Test byte token conversion token = "65" # Byte value for 'A' @@ -116,7 +116,7 @@ def test_token_conversion(self): self.assertEqual(bos_token, str(tokenizer.bos_token)) def test_convert_tokens_to_string(self): - tokenizer = BLTTokenizer() + tokenizer = BltTokenizer() tokens = ["72", "101", "108", "108", "111"] # "Hello" in bytes result = tokenizer.convert_tokens_to_string(tokens) @@ -128,7 +128,7 @@ def test_convert_tokens_to_string(self): self.assertEqual(result, "Hi") def test_unicode_handling(self): - tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) # Test Unicode character (é) text = "café" @@ -143,7 +143,7 @@ def test_unicode_handling(self): self.assertEqual(decoded, text) def test_empty_and_whitespace(self): - tokenizer = BLTTokenizer(add_bos_token=False, add_eos_token=False) + tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) # Test empty string encoded = tokenizer.encode("", add_special_tokens=False) @@ -158,7 +158,7 @@ def test_empty_and_whitespace(self): self.assertEqual(decoded, " ") def test_get_vocab(self): - tokenizer = BLTTokenizer() + tokenizer = BltTokenizer() vocab = tokenizer.get_vocab() # Should contain special tokens @@ -177,7 +177,7 @@ def test_get_vocab(self): self.assertEqual(vocab["255"], 259) # 255 + 4 offset def test_build_inputs_with_special_tokens(self): - tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) # Single sequence token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) @@ -193,7 +193,7 @@ def test_build_inputs_with_special_tokens(self): self.assertEqual(result, expected) def test_special_tokens_mask(self): - tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True) + tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) mask = tokenizer.get_special_tokens_mask(token_ids) @@ -201,24 +201,24 @@ def test_special_tokens_mask(self): self.assertEqual(mask, expected) def test_add_special_tokens_flags(self): - tokenizer1 = BLTTokenizer(add_bos_token=True, add_eos_token=True) + tokenizer1 = BltTokenizer(add_bos_token=True, add_eos_token=True) encoded1 = tokenizer1.encode("Hi", add_special_tokens=True) self.assertEqual(encoded1[0], 1) # BOS self.assertEqual(encoded1[-1], 2) # EOS - tokenizer2 = BLTTokenizer(add_bos_token=False, add_eos_token=False) + tokenizer2 = BltTokenizer(add_bos_token=False, add_eos_token=False) encoded2 = tokenizer2.encode("Hi", add_special_tokens=True) self.assertNotEqual(encoded2[0], 1) # No BOS self.assertNotEqual(encoded2[-1], 2) # No EOS # Test with only BOS - tokenizer3 = BLTTokenizer(add_bos_token=True, add_eos_token=False) + tokenizer3 = BltTokenizer(add_bos_token=True, add_eos_token=False) encoded3 = tokenizer3.encode("Hi", add_special_tokens=True) self.assertEqual(encoded3[0], 1) # BOS self.assertNotEqual(encoded3[-1], 2) # No EOS def test_added_tokens(self): - tokenizer = BLTTokenizer() + tokenizer = BltTokenizer() custom_token = AddedToken("", normalized=False, special=True) tokenizer.add_tokens([custom_token]) @@ -231,19 +231,19 @@ def test_added_tokens(self): back_token = tokenizer._convert_id_to_token(token_id) self.assertEqual(back_token, "") - @unittest.skip("BLT is byte-level, special tokens are encoded as bytes") + @unittest.skip("Blt is byte-level, special tokens are encoded as bytes") def test_add_special_tokens(self): pass - @unittest.skip("BLT byte-level tokenization doesn't handle pretokenized inputs the same way") + @unittest.skip("Blt byte-level tokenization doesn't handle pretokenized inputs the same way") def test_pretokenized_inputs(self): pass - @unittest.skip("BLT encodes added tokens as bytes, not single tokens") + @unittest.skip("Blt encodes added tokens as bytes, not single tokens") def test_add_tokens_tokenizer(self): pass - @unittest.skip("BLT tokenizer serialization needs additional work for added tokens") + @unittest.skip("Blt tokenizer serialization needs additional work for added tokens") def test_save_and_load_tokenizer(self): pass From 703fab75f47ec3f0de7e5aefcae2e11f88d4dead Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 29 Jul 2025 11:01:39 +0200 Subject: [PATCH 100/139] Reorder --- src/transformers/models/blt/modeling_blt.py | 32 --------------------- 1 file changed, 32 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 48af5f56d14b..c4e083311811 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -805,12 +805,6 @@ class BltPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True - - config_class = BltConfig - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False - _supports_cache_class = False - _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), @@ -818,24 +812,6 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } - def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, BltRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.RMSNorm): - module.weight.data.fill_(1.0) def _update_causal_mask( self, @@ -1406,11 +1382,3 @@ def forward( output.global_attentions = outputs.global_attentions return output - - -__all__ = [ - "BltPreTrainedModel", - "BltModel", - "BltPatcher", - "BltForCausalLM", -] \ No newline at end of file From ec9b4c08f1406f6e7f506f2dde4a2de62abf552c Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 29 Jul 2025 11:08:02 +0200 Subject: [PATCH 101/139] Reorder --- src/transformers/models/blt/modeling_blt.py | 299 ++++++++++---------- 1 file changed, 154 insertions(+), 145 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index c4e083311811..0407313b2ec6 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -598,6 +598,152 @@ def forward( return hidden_states + + +@auto_docstring +class BltPreTrainedModel(PreTrainedModel): + """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" + + config: BltConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + + _supports_static_cache = False # static cache cannot have different shapes for each layer + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ 1000000007, @@ -791,150 +937,6 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona -@auto_docstring -class BltPreTrainedModel(PreTrainedModel): - """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" - - config: BltConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - - _supports_static_cache = False # static cache cannot have different shapes for each layer - _supports_sdpa = True - _supports_flash_attn = True - _supports_flex_attn = True - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), - "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), - } - - - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Top-level model classes class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): @@ -1375,10 +1377,17 @@ def forward( attentions=outputs.attentions, ) - # Add Blt-specific attention outputs + # Add Blt-specific attention outputs if hasattr(outputs, 'encoder_attentions'): output.encoder_attentions = outputs.encoder_attentions if hasattr(outputs, 'global_attentions'): output.global_attentions = outputs.global_attentions return output + +__all__ = [ + "BltPreTrainedModel", + "BltModel", + "BltPatcher", + "BltForCausalLM", +] \ No newline at end of file From 3117a0388e4f0e1baca6e54ea0a3fbab2c66d750 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 29 Jul 2025 16:21:49 +0200 Subject: [PATCH 102/139] working on modular --- src/transformers/models/blt/modeling_blt.py | 29 ++++++++++++--------- src/transformers/models/blt/modular_blt.py | 6 +++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 0407313b2ec6..7d15ccc314b4 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -278,10 +278,11 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_impl = getattr(self.config, "_attn_implementation", None) or "eager" + attention_interface: Callable = eager_attention_forward - if attn_impl != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, @@ -540,10 +541,10 @@ def forward( raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attn_impl = getattr(self.config, "_attn_implementation", None) or "eager" attention_interface: Callable = eager_attention_forward - if attn_impl != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, @@ -943,6 +944,12 @@ def __init__(self, config: BltConfig): super().__init__(config) self.gradient_checkpointing = False self.config = config + # Set _attn_implementation on all sub-configs as they are not PreTrainedModels + config.patcher_config._attn_implementation = config._attn_implementation + config.encoder_config._attn_implementation = config._attn_implementation + config.decoder_config._attn_implementation = config._attn_implementation + config.global_config._attn_implementation = config._attn_implementation + self.local_encoder = BltLocalEncoder(config.encoder_config) self.global_transformer = BltGlobalTransformer(config.global_config) self.local_decoder = BltLocalDecoder(config.decoder_config) @@ -1250,22 +1257,18 @@ def patch_lengths_from_entropies( """ ) class BltForCausalLM(BltPreTrainedModel, GenerationMixin): - config: BltConfig - _supports_static_cache = True # only the LLM without cross attn can do compile - base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] - supports_gradient_checkpointing = True _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - def __init__(self, config): - super().__init__(config.get_text_config()) - self.text_config = config.get_text_config() + def __init__(self, config: BltConfig): + super().__init__(config) self.vocab_size = config.vocab_size self.model = BltModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) self.post_init() + def get_input_embeddings(self): return self.model.local_encoder.embed_tokens diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index ca9b037922b9..c2eb03f29d1b 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -770,6 +770,12 @@ class BLTModel(BLTPreTrainedModel): def __init__(self, config: BLTConfig): super().__init__(config) self.config = config + + config.patcher_config._attn_implementation = config._attn_implementation + config.encoder_config._attn_implementation = config._attn_implementation + config.decoder_config._attn_implementation = config._attn_implementation + config.global_config._attn_implementation = config._attn_implementation + self.local_encoder = BLTLocalEncoder(config.encoder_config) self.global_transformer = BLTGlobalTransformer(config.global_config) self.local_decoder = BLTLocalDecoder(config.decoder_config) From eb4cd4140f81c9336c33b55917a27f636914af6b Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 29 Jul 2025 17:32:27 +0200 Subject: [PATCH 103/139] updates from modular --- src/transformers/models/blt/modeling_blt.py | 164 ++++++++++---------- tests/models/blt/test_modeling_blt.py | 60 +++---- 2 files changed, 118 insertions(+), 106 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 7d15ccc314b4..4e77d0721f2b 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/blt/modular_blt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_blt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 HuggingFace Inc. team. All rights reserved. # @@ -25,13 +31,13 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from ...utils.generic import check_model_inputs, OutputRecorder - +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available +from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_blt import ( BltConfig, BltGlobalTransformerConfig, @@ -44,14 +50,9 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - -if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask -logger = logging.get_logger(__name__) - - class BltMLP(nn.Module): def __init__(self, config): super().__init__() @@ -64,7 +65,7 @@ def __init__(self, config): # Ignore copy self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x, **kwargs: Unpack[TransformersKwargs]): + def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -78,7 +79,7 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states, **kwargs: Unpack[TransformersKwargs]): + def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -109,7 +110,7 @@ def __init__(self, config: BltConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids, **kwargs: Unpack[TransformersKwargs]): + def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -124,40 +125,75 @@ def forward(self, x, position_ids, **kwargs: Unpack[TransformersKwargs]): # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class BltTransformerLayer(nn.Module): +class BltTransformerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size + self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) self.mlp = BltMLP(config) self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) - hidden_states, attn_weights = self.self_attn( + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_embeddings=position_embeddings, + position_ids=position_ids, past_key_value=past_key_value, + use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states + + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states + return hidden_states @@ -247,6 +283,7 @@ def __init__(self, config: BltConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.rope_theta = config.rope_theta self.layer_idx = layer_idx + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -256,16 +293,15 @@ def __init__(self, config: BltConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, ): - # Ensure hidden_states is always 3D (batch, seq_len, hidden_dim) - if hidden_states.dim() == 2: - hidden_states = hidden_states.unsqueeze(0) bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -273,13 +309,17 @@ def forward( query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -293,8 +333,10 @@ def forward( scaling=self.scaling, **kwargs, ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights @@ -455,9 +497,7 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds if position_ids is None: - position_ids = ( - torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + position_ids = torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): @@ -503,7 +543,6 @@ def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - # needs to stay hidden_size, NOT head_dim self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.is_causal = False @@ -517,7 +556,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], - ): + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) @@ -544,7 +584,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = attention_interface( self, query_states, @@ -599,8 +638,6 @@ def forward( return hidden_states - - @auto_docstring class BltPreTrainedModel(PreTrainedModel): """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" @@ -622,7 +659,6 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } - def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -937,19 +973,16 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded - -# Top-level model classes class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): super().__init__(config) self.gradient_checkpointing = False - self.config = config - # Set _attn_implementation on all sub-configs as they are not PreTrainedModels config.patcher_config._attn_implementation = config._attn_implementation config.encoder_config._attn_implementation = config._attn_implementation config.decoder_config._attn_implementation = config._attn_implementation config.global_config._attn_implementation = config._attn_implementation - + + self.config = config self.local_encoder = BltLocalEncoder(config.encoder_config) self.global_transformer = BltGlobalTransformer(config.global_config) self.local_decoder = BltLocalDecoder(config.decoder_config) @@ -1027,9 +1060,7 @@ def forward( ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, encoder_embeds, cache_position, past_key_values - ) + causal_mask = self._update_causal_mask(attention_mask, encoder_embeds, cache_position, past_key_values) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) @@ -1050,13 +1081,9 @@ def forward( **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_cache_position = torch.arange( - 0, global_hidden_states.shape[1], device=global_hidden_states.device - ) + global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) global_position_ids = global_cache_position.unsqueeze(0) - global_causal_mask = self._update_causal_mask( - None, global_hidden_states, global_cache_position, None - ) + global_causal_mask = self._update_causal_mask(None, global_hidden_states, global_cache_position, None) global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, @@ -1164,7 +1191,9 @@ def forward( ) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + layer_outputs = layer( + hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask + ) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) @@ -1251,11 +1280,6 @@ def patch_lengths_from_entropies( return patch_lengths -@auto_docstring( - custom_intro=""" - The Blt Text Model with a language modeling head on top. - """ -) class BltForCausalLM(BltPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] @@ -1268,7 +1292,6 @@ def __init__(self, config: BltConfig): self.post_init() - def get_input_embeddings(self): return self.model.local_encoder.embed_tokens @@ -1281,12 +1304,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - @can_return_tuple @auto_docstring def forward( @@ -1333,9 +1350,9 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, BltForCausalLM + >>> from transformers import AutoTokenizer, XBltForCausalLM - >>> model = BltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> model = XBltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") >>> prompt = "If I had to write a haiku, it would be:" @@ -1372,25 +1389,12 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - output = CausalLMOutputWithPast( + return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - # Add Blt-specific attention outputs - if hasattr(outputs, 'encoder_attentions'): - output.encoder_attentions = outputs.encoder_attentions - if hasattr(outputs, 'global_attentions'): - output.global_attentions = outputs.global_attentions - - return output - -__all__ = [ - "BltPreTrainedModel", - "BltModel", - "BltPatcher", - "BltForCausalLM", -] \ No newline at end of file + +__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 0ece0ec2e288..b7a3df150dba 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -314,33 +314,41 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_flex_attention_with_grads(): return - def test_attention_outputs(self): - if not self.has_attentions: - self.skipTest(reason="Model does not output attentions") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - config._attn_implementation = "eager" - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class._from_config(config, attn_implementation="eager") - config = model.config - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - # For Blt, check separate attention outputs from each component - attentions = outputs.attentions - encoder_attentions = outputs.encoder_attentions - global_attentions = outputs.global_attentions + @unittest.skip(reason="Padding with patcher is complex") + def test_eager_padding_matches_padding_free_with_position_ids(): + return + + @unittest.skip(reason="Padding with patcher is complex") + def test_sdpa_padding_matches_padding_free_with_position_ids(): + return + + # def test_attention_outputs(self): + # if not self.has_attentions: + # self.skipTest(reason="Model does not output attentions") + + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # config.return_dict = True + # config._attn_implementation = "eager" + + # for model_class in self.all_model_classes: + # inputs_dict["output_attentions"] = True + # inputs_dict["output_hidden_states"] = False + # config.return_dict = True + # model = model_class._from_config(config, attn_implementation="eager") + # config = model.config + # model.to(torch_device) + # model.eval() + # with torch.no_grad(): + # outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + # # For Blt, check separate attention outputs from each component + # attentions = outputs.attentions + # encoder_attentions = outputs.encoder_attentions + # global_attentions = outputs.global_attentions - # Each component should have attention outputs equal to their layer count - self.assertEqual(len(attentions), config.decoder_config.num_hidden_layers) - self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) - self.assertEqual(len(global_attentions), config.global_config.num_hidden_layers) + # # Each component should have attention outputs equal to their layer count + # self.assertEqual(len(attentions), config.decoder_config.num_hidden_layers) + # self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) + # self.assertEqual(len(global_attentions), config.global_config.num_hidden_layers) @require_torch_accelerator From c9e30fd986447170542892f2ad5fd819a0f562fe Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 29 Jul 2025 18:26:02 +0200 Subject: [PATCH 104/139] new modular --- src/transformers/models/blt/modeling_blt.py | 72 +- src/transformers/models/blt/modular_blt.py | 749 ++++++++++---------- 2 files changed, 381 insertions(+), 440 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 4e77d0721f2b..b2e88f00a4ad 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -300,6 +300,8 @@ def forward( cache_position=None, **kwargs, ): + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -1280,29 +1282,32 @@ def patch_lengths_from_entropies( return patch_lengths +@auto_docstring( + custom_intro=""" + The Blt Text Model with a language modeling head on top. + """ +) class BltForCausalLM(BltPreTrainedModel, GenerationMixin): + config: BltConfig + _can_compile_fullgraph = False + base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config: BltConfig): - super().__init__(config) + super().__init__(config.get_text_config()) + self.text_config = config.get_text_config() self.vocab_size = config.vocab_size self.model = BltModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.local_encoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.local_encoder.embed_tokens = value + def set_decoder(self, decoder): + self.model = decoder - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + def get_decoder(self): + return self.model @can_return_tuple @auto_docstring @@ -1321,53 +1326,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> Union[tuple, CausalLMOutputWithPast]: r""" - cross_attention_states (`torch.FloatTensor`, *optional*): - Output of the vision model, used for cross-attention. This tensor contains the processed image features that - the language model will attend to. - cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): - Cross-attention mask to control the interaction between text tokens and image tiles. - This 4D tensor defines which image tiles each text token should attend to. - - For each text token (in seq_length): - - 1 indicates the token **should attend** to the corresponding image tile - - 0 indicates the token **should not attend** to the corresponding image tile - full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*): - A tuple containing two tensors that mask out rows in the cross-attention mechanism: - - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. - A value of 0 indicates that the corresponding text token's entire row in the cross-attention - matrix should be masked out (all image tokens ignored). - - The second tensor has the same shape and is used internally to apply the masking during - the forward pass of cross-attention layers. - This mask is derived from the cross_attention_mask and is used to handle cases where a text token - should not attend to any image token. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, XBltForCausalLM - - >>> model = XBltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` """ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( + outputs = self.model( input_ids=input_ids, cross_attention_states=cross_attention_states, attention_mask=attention_mask, @@ -1381,7 +1348,7 @@ def forward( **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]).float() @@ -1397,4 +1364,5 @@ def forward( attentions=outputs.attentions, ) + __all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index c2eb03f29d1b..5248a2a1dbbd 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BLT modular model, inheriting from Mllama where appropriate.""" +"""Blt modular model, inheriting from Mllama where appropriate.""" from typing import Optional, Union @@ -22,19 +22,26 @@ import torch.nn.functional as F from ...cache_utils import Cache -from ...modeling_outputs import BaseModelOutputWithPast +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel + +from ...processing_utils import Unpack from ...utils import is_torch_flex_attn_available, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import check_model_inputs, OutputRecorder from .configuration_blt import ( - BLTConfig, - BLTGlobalTransformerConfig, - BLTLocalDecoderConfig, - BLTLocalEncoderConfig, - BLTPatcherConfig, + BltConfig, + BltGlobalTransformerConfig, + BltLocalDecoderConfig, + BltLocalEncoderConfig, + BltPatcherConfig, ) if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask + from ...integrations.flex_attention import make_flex_block_causal_mask from ..mllama.modeling_mllama import ( @@ -46,6 +53,8 @@ MllamaTextMLP, MllamaTextRMSNorm, MllamaTextSelfAttention, + eager_attention_forward, + repeat_kv ) @@ -55,14 +64,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # TODO: not exactly equivalent to other transformers implementations,, need feedback # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original BLT approach which uses head_dim//2 frequency pairs + # This matches the original Blt approach which uses head_dim//2 frequency pairs head_dim = q.shape[-1] cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] @@ -277,16 +286,16 @@ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optiona return padded -class BLTMLP(MllamaTextMLP): +class BltMLP(MllamaTextMLP): pass -class BLTRMSNorm(MllamaTextRMSNorm): +class BltRMSNorm(MllamaTextRMSNorm): pass -class BLTRotaryEmbedding(MllamaRotaryEmbedding): - def __init__(self, config: BLTConfig, device=None): +class BltRotaryEmbedding(MllamaRotaryEmbedding): + def __init__(self, config: BltConfig, device=None): super().__init__(config=config, device=device) # BC: "rope_type" was originally "type" self.rope_type = ( @@ -296,85 +305,68 @@ def __init__(self, config: BLTConfig, device=None): ) -class BLTPreTrainedModel(MllamaPreTrainedModel): - """BLT PreTrainedModel inheriting from Mllama but with BLT-specific init.""" - - config_class = BLTConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = False - _supports_sdpa = True - _supports_cache_class = False - - def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, BLTRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.RMSNorm): - module.weight.data.fill_(1.0) - - -class BLTTransformerLayer(MllamaSelfAttentionDecoderLayer): +class BltTransformerLayer(MllamaSelfAttentionDecoderLayer): def __init__(self, config, layer_idx: int): super().__init__() - self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) - self.mlp = BLTMLP(config) - self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BltMLP(config) + self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) -class BLTSelfAttention(MllamaTextSelfAttention): - """BLT variant of MllamaTextSelfAttention. Inherits all logic directly.""" +class BltSelfAttention(MllamaTextSelfAttention): + """Blt variant of MllamaTextSelfAttention. Inherits all logic directly.""" - def __init__(self, config: BLTConfig, layer_idx: int): + def __init__(self, config: BltConfig, layer_idx: int): super().__init__(config, layer_idx) self.is_causal = True + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + use_cache=use_cache, + past_key_value=past_key_value, + cache_position=cache_position, + **kwargs, + ) -# BLT-SPECIFIC COMPONENTS (no Mllama equivalent) +# Blt-SPECIFIC COMPONENTS (no Mllama equivalent) -class BLTLocalEncoder(nn.Module): - def __init__(self, config: BLTLocalEncoderConfig): +class BltLocalEncoder(nn.Module): + def __init__(self, config: BltLocalEncoderConfig): super().__init__() - - self.config = config self.gradient_checkpointing = False - + self.config = config self.layers = nn.ModuleList( - [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - - self.rotary_emb = BLTRotaryEmbedding(config=config) - + self.rotary_emb = BltRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size, out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) def forward( @@ -385,85 +377,51 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - """ """ - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - if input_embeds is None: input_embeds = self.embed_tokens(input_ids) - - batch_size, _, _ = input_embeds.shape - + batch_size = input_embeds.shape[0] hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) - - position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - for idx, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - None, # attention_mask - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position - ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=None, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + **kwargs, + ) if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) - layer_idx = idx if self.config.cross_attn_all_layers else 0 + # Remove cross_attention_states from kwargs if present to avoid multiple values error + kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=cross_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, + **kwargs, ) patch_embeds = patch_embeds + cross_attention_output - encoder_cross_states = patch_embeds - return hidden_states, encoder_cross_states, all_hidden_states, all_attentions + return hidden_states, encoder_cross_states def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): """ @@ -476,7 +434,8 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ - batch_size, _, embedding_dim = hidden_states.shape + batch_size = hidden_states.shape[0] + embedding_dim = hidden_states.shape[-1] patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) @@ -495,36 +454,30 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): return reduced_embeddings -class BLTLocalDecoder(nn.Module): - def __init__(self, config: BLTLocalDecoderConfig): +class BltLocalDecoder(nn.Module): + def __init__(self, config: BltLocalDecoderConfig): super().__init__() - - # Extract config values to instance attributes - self.config = config - self.cross_attn_decoder = True # config.cross_attn_decoder #TODO: maybe remove self.gradient_checkpointing = False - + self.config = config + self.cross_attn_decoder = True self.layers = nn.ModuleList( - [BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - - self.rotary_emb = BLTRotaryEmbedding(config=config) - + self.rotary_emb = BltRotaryEmbedding(config=config) self.patch_embedding_projection = nn.Linear( in_features=config.hidden_size_global, out_features=config.hidden_size * config.cross_attn_k, bias=False, ) - - self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_layers = nn.ModuleList() layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 for layer_idx in range(layers_to_add): self.cross_attn_layers.append( - BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -533,95 +486,57 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - batch_size, _, _ = embeds.shape - + batch_size = embeds.shape[0] hidden_states = embeds - patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) - if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds - - # Use sequence length from embeds (standard transformers pattern) - seq_len = embeds.shape[1] - position_ids = torch.arange(seq_len, device=embeds.device).unsqueeze(0).expand(batch_size, -1) + if position_ids is None: + position_ids = ( + torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - if i == 0 or self.config.cross_attn_all_layers: - # Use cross attention to extract info from patch_embeds into hidden_states + # Remove cross_attention_states from kwargs if present to avoid multiple values error + kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=cross_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - output_attentions=False, - use_cache=False, - cache_position=None, + **kwargs, ) hidden_states = hidden_states + cross_attention_output - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - None, # attention_mask - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position - ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=None, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - # Add final hidden state after all layers - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + **kwargs, + ) logits = self.norm(hidden_states) - # logits = self.lm_head(logits) - return logits, all_hidden_states, all_attentions - + return logits -class BLTCrossAttention(MllamaTextCrossAttention): - """Cross-attention module for BLT, following transformers style""" +class BltCrossAttention(MllamaTextCrossAttention): + """Cross-attention module for Blt, following transformers style""" - def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): super().__init__() self.is_causal = False - self.q_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = BLTRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -629,32 +544,25 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ): bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + cross_attention_states = self.k_norm(cross_attention_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - elif cache_position[0] != 0: + elif cache_position is not None and cache_position[0] != 0: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], past_key_value.value_cache[self.layer_idx], @@ -663,18 +571,9 @@ def forward( raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, @@ -685,29 +584,20 @@ def forward( scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) attn_output = attn_output + hidden_states - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights -class BLTGlobalTransformer(nn.Module): - def __init__(self, config: BLTGlobalTransformerConfig): +class BltGlobalTransformer(nn.Module): + def __init__(self, config: BltGlobalTransformerConfig): super().__init__() - self.config = config - self.gradient_checkpointing = False - self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(config, layer_idx)) - - self.rotary_emb = BLTRotaryEmbedding(config=config) + self.layers.append(BltTransformerLayer(config, layer_idx)) + self.rotary_emb = BltRotaryEmbedding(config=config) def forward( self, @@ -715,71 +605,186 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, - cache: Optional[list[tuple[torch.Tensor, torch.Tensor, int]]] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ): - # Initialize output collections - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - batch_size, seq_len, _ = input_embeds.shape - hidden_states = input_embeds - hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) - - position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + if position_ids is None: + position_ids = ( + torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) position_embeddings = self.rotary_emb(hidden_states, position_ids) - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - position_embeddings, - None, # attention_mask - None, # past_key_value - False, # output_attentions - False, # use_cache - None, # cache_position + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + return hidden_states + + +@auto_docstring +class BltPreTrainedModel(PreTrainedModel): + """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" + + config: BltConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + + _supports_static_cache = False # static cache cannot have different shapes for each layer + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device ) - else: - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=None, - output_attentions=output_attentions, + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - return hidden_states, all_hidden_states, all_attentions + return causal_mask -class BLTModel(BLTPreTrainedModel): - def __init__(self, config: BLTConfig): +class BltModel(BltPreTrainedModel): + def __init__(self, config: BltConfig): super().__init__(config) - self.config = config - + self.gradient_checkpointing = False config.patcher_config._attn_implementation = config._attn_implementation config.encoder_config._attn_implementation = config._attn_implementation config.decoder_config._attn_implementation = config._attn_implementation config.global_config._attn_implementation = config._attn_implementation - self.local_encoder = BLTLocalEncoder(config.encoder_config) - self.global_transformer = BLTGlobalTransformer(config.global_config) - self.local_decoder = BLTLocalDecoder(config.decoder_config) - + self.config = config + self.local_encoder = BltLocalEncoder(config.encoder_config) + self.global_transformer = BltGlobalTransformer(config.global_config) + self.local_decoder = BltLocalDecoder(config.decoder_config) num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) embeddings = [ nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) @@ -787,15 +792,15 @@ def __init__(self, config: BLTConfig): ] self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) if self.config.patch_in_forward: - self.patcher = BLTPatcher(config.patcher_config) + self.patcher = BltPatcher(config.patcher_config) self.patcher.eval() for param in self.patcher.parameters(): param.requires_grad = False else: self.patcher = None - self.post_init() + @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -805,38 +810,17 @@ def forward( past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Union[BaseModelOutputWithPast, tuple]: - """ - Args: - input_ids (torch.LongTensor, optional): Input token ids. - patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. - attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. - Returns: - Union[BaseModelOutputWithPast, tuple]: Model outputs. - """ - # Set defaults from config when parameters are None - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Explicit input validation (not XOR) + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is None and inputs_embeds is None: raise ValueError("You have to specify either input_ids or inputs_embeds") - if input_ids is not None: batch_size, sequence_length = input_ids.shape else: batch_size, sequence_length, _ = inputs_embeds.shape - # Handle patching if patch_lengths is None: if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: @@ -868,42 +852,51 @@ def forward( self.config.encoder_hash_byte_group_size, self.config.encoder_hash_byte_group_vocab, ) - if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - - causal_mask = attention_mask - + causal_mask = self._update_causal_mask( + attention_mask, encoder_embeds, cache_position, past_key_values + ) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - encoder_hidden_states, encoder_cross_states, encoder_hidden_states_all, encoder_attentions_all = ( - self.local_encoder( - input_ids=input_ids, - input_embeds=encoder_embeds, - patch_embeds=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, - num_patches=patch_lengths.shape[1], - patch_ids=patch_ids, - output_attentions=False, # Don't collect encoder attentions - output_hidden_states=False, # Don't collect encoder hidden states - ) + # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error + kwargs.pop("full_text_row_masked_out_mask", None) + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=input_ids, + input_embeds=encoder_embeds, + patch_embeds=None, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=None, + cache_position=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + **kwargs, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_hidden_states, global_hidden_states_all, global_attentions_all = self.global_transformer( + global_cache_position = torch.arange( + 0, global_hidden_states.shape[1], device=global_hidden_states.device + ) + global_position_ids = global_cache_position.unsqueeze(0) + global_causal_mask = self._update_causal_mask( + None, global_hidden_states, global_cache_position, None + ) + global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, - output_attentions=False, # Don't collect global transformer attentions - output_hidden_states=False, # Don't collect global transformer hidden states + attention_mask=global_causal_mask, + position_ids=global_position_ids, + past_key_values=None, + cache_position=None, + **kwargs, ) - decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, @@ -913,51 +906,31 @@ def forward( self.config.cross_attn_k, encoder_embeds.dtype, ) - output, decoder_hidden_states_all, decoder_attentions_all = self.local_decoder( + output = self.local_decoder( input_ids=input_ids, embeds=encoder_hidden_states, patch_embeds=global_hidden_states, attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, mask=None, cross_mask=cross_attn_mask_dec, full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, - output_attentions=output_attentions, # Only collect decoder attentions - output_hidden_states=output_hidden_states, # Only collect decoder hidden states - ) - - # Only use decoder outputs (which match the expected num_hidden_layers) - all_hidden_states = ( - decoder_hidden_states_all if output_hidden_states and decoder_hidden_states_all is not None else None + **kwargs, ) - all_attentions = decoder_attentions_all if output_attentions and decoder_attentions_all is not None else None - - if not return_dict: - outputs = (output,) - if past_key_values is not None: - outputs = outputs + (past_key_values,) - if all_hidden_states is not None: - outputs = outputs + (all_hidden_states,) - if all_attentions is not None: - outputs = outputs + (all_attentions,) - return outputs - return BaseModelOutputWithPast( last_hidden_state=output, past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_attentions, ) def get_input_embeddings(self): - """Returns the model's input embeddings.""" return self.local_encoder.embed_tokens def set_input_embeddings(self, value): - """Sets the model's input embeddings.""" self.local_encoder.embed_tokens = value def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: - """Convert patch lengths to patch IDs for each token position.""" batch_size = patch_lengths.shape[0] patch_starts = torch.cat( [ @@ -966,21 +939,19 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> ], dim=-1, ) - token_positions = torch.arange(seq_len, device=patch_lengths.device) return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 -class BLTPatcher(BLTPreTrainedModel): - def __init__(self, config: BLTPatcherConfig): +class BltPatcher(BltPreTrainedModel): + def __init__(self, config: BltPatcherConfig): super().__init__(config) - self.rotary_emb = BLTRotaryEmbedding(config=self.config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) self.layers = nn.ModuleList() for layer_idx in range(self.config.num_hidden_layers): - self.layers.append(BLTTransformerLayer(self.config, layer_idx)) - + self.layers.append(BltTransformerLayer(self.config, layer_idx)) self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, @@ -995,14 +966,13 @@ def forward( max_patch_length: Optional[int] = None, patching_batch_size: int = 1, device: Optional[str] = None, + **kwargs: Unpack[TransformersKwargs], ): - # Handle chunked processing for entropy calculation entropies = [] predictions = [] max_length = self.config.max_position_embeddings batch_numel = max_length * patching_batch_size splits = torch.split(token_values.flatten(), batch_numel) - for split in splits: pad_size = (max_length - (split.numel() % max_length)) % max_length pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) @@ -1010,21 +980,23 @@ def forward( split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - - # Process chunk: embeddings -> layers -> output batch_size, sequence_length = split.shape input_embeds = self.embed_tokens(split) - hidden_states = input_embeds - - batch_size, _, _ = input_embeds.shape - + batch_size = input_embeds.shape[0] position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) + cache_position = torch.arange(sequence_length, device=input_embeds.device) + causal_mask = self._update_causal_mask( + None, # attention_mask + input_embeds, + cache_position, + None, # past_key_values + ) + for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) @@ -1032,14 +1004,9 @@ def forward( predictions.append(logits) prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() entropies.append(prediction_entropies) - concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - - # Always compute patch lengths from concatenated entropies batch_size, sequence_length = token_values.shape - - # Find patch start IDs based on entropy if patch_size is not None: patch_lengths = self.patch_lengths_from_entropies( entropies=concat_entropies, @@ -1048,7 +1015,6 @@ def forward( threshold=threshold, ) else: - # Default to byte-level patching patch_lengths = torch.ones( (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device ) @@ -1117,30 +1083,37 @@ def patch_lengths_from_entropies( return patch_lengths -class BLTForCausalLM(MllamaForCausalLM): - config: BLTConfig - supports_gradient_checkpointing = True +class BltForCausalLM(MllamaForCausalLM): + config: BltConfig base_model_prefix = "model" + _can_compile_fullgraph = False _tied_weights_keys = ["lm_head.weight"] - _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - def __init__(self, config): + def __init__(self, config: BltConfig): super().__init__(config) - self.model = BLTModel(config) self.vocab_size = config.vocab_size + self.model = BltModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() - def get_input_embeddings(self): - return self.model.local_encoder.embed_tokens - def set_input_embeddings(self, value): - self.model.local_encoder.embed_tokens = value + @can_return_tuple + @auto_docstring + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + super().forward(**super_kwargs) __all__ = [ - "BLTPreTrainedModel", - "BLTModel", - "BLTPatcher", - "BLTForCausalLM", + "BltPreTrainedModel", + "BltModel", + "BltPatcher", + "BltForCausalLM", ] From 3b2e3e85e838bb2854023f1ffe406b7529fbe312 Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 30 Jul 2025 13:12:55 +0200 Subject: [PATCH 105/139] ruff and such --- src/transformers/models/blt/modular_blt.py | 17 +++---- tests/models/blt/test_modeling_blt.py | 53 ++++++++++------------ 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 5248a2a1dbbd..ee46eb301a87 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,7 @@ # limitations under the License. """Blt modular model, inheriting from Mllama where appropriate.""" -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.distributions @@ -22,14 +22,12 @@ import torch.nn.functional as F from ...cache_utils import Cache -from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel - from ...processing_utils import Unpack -from ...utils import is_torch_flex_attn_available, logging from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from ...utils.generic import check_model_inputs, OutputRecorder +from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_blt import ( BltConfig, BltGlobalTransformerConfig, @@ -41,12 +39,12 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask + from ...integrations.flex_attention import make_flex_block_causal_mask from ..mllama.modeling_mllama import ( MllamaForCausalLM, - MllamaPreTrainedModel, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, @@ -54,7 +52,6 @@ MllamaTextRMSNorm, MllamaTextSelfAttention, eager_attention_forward, - repeat_kv ) @@ -626,7 +623,7 @@ def forward( **kwargs, ) return hidden_states - + @auto_docstring class BltPreTrainedModel(PreTrainedModel): @@ -1099,7 +1096,7 @@ def __init__(self, config: BltConfig): self.post_init() - @can_return_tuple + @can_return_tuple @auto_docstring def forward(self, **super_kwargs): r""" diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index b7a3df150dba..dcbf89d6168b 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -317,7 +317,7 @@ def test_flex_attention_with_grads(): @unittest.skip(reason="Padding with patcher is complex") def test_eager_padding_matches_padding_free_with_position_ids(): return - + @unittest.skip(reason="Padding with patcher is complex") def test_sdpa_padding_matches_padding_free_with_position_ids(): return @@ -342,12 +342,12 @@ def test_sdpa_padding_matches_padding_free_with_position_ids(): # outputs = model(**self._prepare_for_class(inputs_dict, model_class)) # # For Blt, check separate attention outputs from each component # attentions = outputs.attentions - # encoder_attentions = outputs.encoder_attentions + # encoder_attentions = outputs.encoder_attentions # global_attentions = outputs.global_attentions - + # # Each component should have attention outputs equal to their layer count # self.assertEqual(len(attentions), config.decoder_config.num_hidden_layers) - # self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) + # self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) # self.assertEqual(len(global_attentions), config.global_config.num_hidden_layers) @@ -368,17 +368,15 @@ def test_model(self): prompt = "my name is" - model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", - device_map="auto", - attn_implementation="sdpa" - ) + model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa") tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False + ) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -475,16 +473,16 @@ def test_model_bf16(self): prompt = "my name is" model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", - device_map="auto", - attn_implementation="sdpa", - torch_dtype=torch.bfloat16) + "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + ) tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False + ) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -567,10 +565,8 @@ def test_model_logits_bf16(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", - device_map="auto", - attn_implementation="sdpa", - torch_dtype=torch.bfloat16) + "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + ) with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -588,16 +584,15 @@ def test_model_eager(self): prompt = "my name is" - model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", - device_map="auto", - attn_implementation="eager") + model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="eager") tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False + ) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) @@ -613,10 +608,8 @@ def test_model_bf16_static_cache(self): prompt = "my name is" model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", - device_map="auto", - attn_implementation="sdpa", - torch_dtype=torch.bfloat16) + "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + ) model.generation_config.cache_implementation = "static" @@ -624,7 +617,9 @@ def test_model_bf16_static_cache(self): inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False + ) output_text = tokenizer.decode(generated_ids[0]) self.assertEqual(output_text, EXPECTED_TEXT) From 2ded41e055ea06908b146e9d13cc9eab83a2ca93 Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 30 Jul 2025 13:36:58 +0200 Subject: [PATCH 106/139] update pretrainedmodel modular --- src/transformers/models/blt/modeling_blt.py | 55 +++++++- src/transformers/models/blt/modular_blt.py | 136 +------------------- 2 files changed, 53 insertions(+), 138 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index b2e88f00a4ad..788840482847 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -50,6 +50,8 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask + +if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask @@ -342,9 +344,6 @@ def forward( return attn_output, attn_weights -# Blt-SPECIFIC COMPONENTS (no Mllama equivalent) - - class BltLocalEncoder(nn.Module): def __init__(self, config: BltLocalEncoderConfig): super().__init__() @@ -648,8 +647,7 @@ class BltPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - - _supports_static_cache = False # static cache cannot have different shapes for each layer + _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True @@ -661,12 +659,18 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } + _supports_static_cache = False # static cache cannot have different shapes for each layer + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -684,7 +688,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -718,6 +722,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -1328,10 +1333,48 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. + cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): + Cross-attention mask to control the interaction between text tokens and image tiles. + This 4D tensor defines which image tiles each text token should attend to. + + For each text token (in seq_length): + - 1 indicates the token **should attend** to the corresponding image tile + - 0 indicates the token **should not attend** to the corresponding image tile + full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*): + A tuple containing two tensors that mask out rows in the cross-attention mechanism: + - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1. + A value of 0 indicates that the corresponding text token's entire row in the cross-attention + matrix should be masked out (all image tokens ignored). + - The second tensor has the same shape and is used internally to apply the masking during + the forward pass of cross-attention layers. + This mask is derived from the cross_attention_mask and is used to handle cases where a text token + should not attend to any image token. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BltForCausalLM + + >>> model = BltForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` """ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index ee46eb301a87..f45a4aa4ae5b 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -22,7 +22,6 @@ import torch.nn.functional as F from ...cache_utils import Cache -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -40,11 +39,11 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - from ...integrations.flex_attention import make_flex_block_causal_mask from ..mllama.modeling_mllama import ( MllamaForCausalLM, + MllamaPreTrainedModel, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, @@ -342,8 +341,6 @@ def forward( ) -# Blt-SPECIFIC COMPONENTS (no Mllama equivalent) - class BltLocalEncoder(nn.Module): def __init__(self, config: BltLocalEncoderConfig): super().__init__() @@ -626,7 +623,7 @@ def forward( @auto_docstring -class BltPreTrainedModel(PreTrainedModel): +class BltPreTrainedModel(MllamaPreTrainedModel): """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" config: BltConfig @@ -646,127 +643,8 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } - - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) class BltModel(BltPreTrainedModel): @@ -1099,12 +977,6 @@ def __init__(self, config: BltConfig): @can_return_tuple @auto_docstring def forward(self, **super_kwargs): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ super().forward(**super_kwargs) From cd7d1a8d10f1527de3e81b80d48c392c7c1c9f25 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 31 Jul 2025 16:29:49 +0000 Subject: [PATCH 107/139] using cohere2 apply_rotary_pos_emb --- src/demo_hf.py | 20 ++--- src/transformers/models/blt/modeling_blt.py | 61 +++++++-------- src/transformers/models/blt/modular_blt.py | 83 ++++++++------------- 3 files changed, 70 insertions(+), 94 deletions(-) diff --git a/src/demo_hf.py b/src/demo_hf.py index 3074b0676c41..83f6c3e2d1a8 100644 --- a/src/demo_hf.py +++ b/src/demo_hf.py @@ -1,47 +1,39 @@ +import gc import logging import os import torch +from transformers import AutoTokenizer from transformers.models.blt.modeling_blt import BLTForCausalLM -from transformers.models.blt.tokenization_blt import BLTTokenizer - -from transformers import AutoTokenizer logger = logging.getLogger() os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" -import gc gc.collect() torch.cuda.empty_cache() def main(prompt: str = "my name is", model_name: str = ""): - model = BLTForCausalLM.from_pretrained( - model_name, - device_map="auto", - # attn_implementation="eager" + model_name, + device_map="auto", ) tokenizer = AutoTokenizer.from_pretrained(model_name) - inputs = tokenizer(prompt, return_tensors="pt").to(model.device) generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False, use_cache=False) output_text = tokenizer.decode(generated_ids[0]) - + print(f'Model: "{model_name}"') - # print(f'Prompt: "{prompt}"') + # print(f'Prompt: "{prompt}"') print(f'Completion: "{output_text}"') if __name__ == "__main__": - # SNAPSHOT_PATH = os.path.expanduser("~/.cache/huggingface/hub/models--itazap--blt-1b/snapshots/bb8c23be2c2f065f0ee315ec2066ac1d3c78722a") - # main(model_name=SNAPSHOT_PATH) main(model_name="itazap/blt-1b-testing") - diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 788840482847..a6a67723a57d 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -113,13 +113,14 @@ def __init__(self, config: BltConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): + # Copied from Cohere2RotaryEmbedding.forward inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # Use Cohere2 pattern for compatibility cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -238,41 +239,41 @@ def eager_attention_forward( def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original Blt approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] - sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] + """Applies Rotary Position Embedding to the query and key tensors. - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - return q_rot.type_as(q), k_rot.type_as(k) +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x class BltSelfAttention(nn.Module): - """Blt variant of MllamaTextSelfAttention. Inherits all logic directly.""" - def __init__(self, config: BltConfig, layer_idx: int): super().__init__() self.config = config diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f45a4aa4ae5b..f075d59cb978 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -23,6 +23,7 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging @@ -40,7 +41,6 @@ from torch.nn.attention.flex_attention import BlockMask - from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, @@ -57,39 +57,6 @@ logger = logging.get_logger(__name__) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - # TODO: not exactly equivalent to other transformers implementations,, need feedback - # Extract first head_dim//2 elements which correspond to the unique frequencies - # This matches the original Blt approach which uses head_dim//2 frequency pairs - head_dim = q.shape[-1] - cos_freqs = cos[..., : head_dim // 2] # [B, S, D/2] - sin_freqs = sin[..., : head_dim // 2] # [B, S, D/2] - - # Expand cos/sin to match query/key tensor format [B, H, S, D/2] - cos_freqs = cos_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] - sin_freqs = sin_freqs.unsqueeze(1) # [B, 1, S, D/2] -> [B, H, S, D/2] - - # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... - q_pairs = q.view(*q.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - k_pairs = k.view(*k.shape[:-1], head_dim // 2, 2) # [B, H, S, D/2, 2] - - # Extract real and i parts - q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] - k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] - - # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] - q_real_rot = cos_freqs * q_real - sin_freqs * q_imag - q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag - k_real_rot = cos_freqs * k_real - sin_freqs * k_imag - k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag - - # Recombine pairs and reshape back to original format - q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] - k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] - - return q_rot.type_as(q), k_rot.type_as(k) - - def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ 1000000007, @@ -300,6 +267,22 @@ def __init__(self, config: BltConfig, device=None): else "default" ) + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # Copied from Cohere2RotaryEmbedding.forward + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # Use Cohere2 pattern for compatibility + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + class BltTransformerLayer(MllamaSelfAttentionDecoderLayer): def __init__(self, config, layer_idx: int): @@ -311,9 +294,15 @@ def __init__(self, config, layer_idx: int): self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) -class BltSelfAttention(MllamaTextSelfAttention): - """Blt variant of MllamaTextSelfAttention. Inherits all logic directly.""" +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + +class BltSelfAttention(MllamaTextSelfAttention): def __init__(self, config: BltConfig, layer_idx: int): super().__init__(config, layer_idx) self.is_causal = True @@ -495,9 +484,7 @@ def forward( if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds if position_ids is None: - position_ids = ( - torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) - ) + position_ids = torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) for i, layer in enumerate(self.layers): @@ -523,6 +510,7 @@ def forward( logits = self.norm(hidden_states) return logits + class BltCrossAttention(MllamaTextCrossAttention): """Cross-attention module for Blt, following transformers style""" @@ -734,9 +722,7 @@ def forward( ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, encoder_embeds, cache_position, past_key_values - ) + causal_mask = self._update_causal_mask(attention_mask, encoder_embeds, cache_position, past_key_values) cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) @@ -757,13 +743,9 @@ def forward( **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_cache_position = torch.arange( - 0, global_hidden_states.shape[1], device=global_hidden_states.device - ) + global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) global_position_ids = global_cache_position.unsqueeze(0) - global_causal_mask = self._update_causal_mask( - None, global_hidden_states, global_cache_position, None - ) + global_causal_mask = self._update_causal_mask(None, global_hidden_states, global_cache_position, None) global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, @@ -871,7 +853,9 @@ def forward( ) for i, layer in enumerate(self.layers): - layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + layer_outputs = layer( + hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask + ) hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) @@ -973,7 +957,6 @@ def __init__(self, config: BltConfig): self.post_init() - @can_return_tuple @auto_docstring def forward(self, **super_kwargs): From 01835380c259e1884559095c045e6b88b615517e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 1 Aug 2025 09:42:00 +0000 Subject: [PATCH 108/139] small changes --- src/transformers/models/blt/modeling_blt.py | 56 ++++++++++-------- src/transformers/models/blt/modular_blt.py | 63 +++++++++++---------- 2 files changed, 67 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index a6a67723a57d..3e71bcd16da1 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -266,7 +266,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. + # From Cohere2. Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) @@ -370,7 +370,7 @@ def __init__(self, config: BltLocalEncoderConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - input_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -382,16 +382,20 @@ def forward( patch_ids: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ): - if input_embeds is None: - input_embeds = self.embed_tokens(input_ids) - batch_size = input_embeds.shape[0] - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size = inputs_embeds.shape[0] + hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) + if position_ids is None: position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for idx, layer in enumerate(self.layers): hidden_states = layer( hidden_states, @@ -440,6 +444,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): reduced_embeddings = torch.zeros( (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device ) + reduced_embeddings = reduced_embeddings.scatter_reduce( src=hidden_states, dim=1, @@ -479,7 +484,7 @@ def __init__(self, config: BltLocalDecoderConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -490,18 +495,24 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size = embeds.shape[0] - hidden_states = embeds + batch_size = inputs_embeds.shape[0] + hidden_states = inputs_embeds patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) + if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds + if position_ids is None: - position_ids = torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for i, layer in enumerate(self.layers): if i == 0 or self.config.cross_attn_all_layers: # Remove cross_attention_states from kwargs if present to avoid multiple values error @@ -564,6 +575,7 @@ def forward( query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if cross_attention_states is not None: cross_attention_states = self.k_norm(cross_attention_states) key_states = self.k_proj(cross_attention_states) @@ -574,18 +586,20 @@ def forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - elif cache_position is not None and cache_position[0] != 0: + elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, @@ -642,8 +656,6 @@ def forward( @auto_docstring class BltPreTrainedModel(PreTrainedModel): - """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" - config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -1022,10 +1034,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if input_ids is not None: batch_size, sequence_length = input_ids.shape else: @@ -1076,7 +1086,7 @@ def forward( kwargs.pop("full_text_row_masked_out_mask", None) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, - input_embeds=encoder_embeds, + inputs_embeds=encoder_embeds, patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, @@ -1111,7 +1121,7 @@ def forward( ) output = self.local_decoder( input_ids=input_ids, - embeds=encoder_hidden_states, + inputs_embeds=encoder_hidden_states, patch_embeds=global_hidden_states, attention_mask=causal_mask, position_ids=position_ids, @@ -1124,7 +1134,7 @@ def forward( ) return BaseModelOutputWithPast( last_hidden_state=output, - past_key_values=past_key_values if use_cache else None, + past_key_values=past_key_values, ) def get_input_embeddings(self): diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f075d59cb978..8b23221454e6 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -26,7 +26,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_blt import ( BltConfig, @@ -295,7 +295,7 @@ def __init__(self, config, layer_idx: int): def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. + # From Cohere2. Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) @@ -355,7 +355,7 @@ def __init__(self, config: BltLocalEncoderConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - input_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -367,16 +367,20 @@ def forward( patch_ids: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ): - if input_embeds is None: - input_embeds = self.embed_tokens(input_ids) - batch_size = input_embeds.shape[0] - hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size = inputs_embeds.shape[0] + hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) + if position_ids is None: position_ids = ( - torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for idx, layer in enumerate(self.layers): hidden_states = layer( hidden_states, @@ -425,6 +429,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): reduced_embeddings = torch.zeros( (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device ) + reduced_embeddings = reduced_embeddings.scatter_reduce( src=hidden_states, dim=1, @@ -464,7 +469,7 @@ def __init__(self, config: BltLocalDecoderConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -475,18 +480,24 @@ def forward( full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): - batch_size = embeds.shape[0] - hidden_states = embeds + batch_size = inputs_embeds.shape[0] + hidden_states = inputs_embeds patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) + if patch_embeds is not None and not self.cross_attn_decoder: hidden_states = hidden_states + patch_embeds + if position_ids is None: - position_ids = torch.arange(embeds.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) + ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for i, layer in enumerate(self.layers): if i == 0 or self.config.cross_attn_all_layers: # Remove cross_attention_states from kwargs if present to avoid multiple values error @@ -534,6 +545,7 @@ def forward( query_states = self.q_norm(hidden_states) query_states = self.q_proj(query_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if cross_attention_states is not None: cross_attention_states = self.k_norm(cross_attention_states) key_states = self.k_proj(cross_attention_states) @@ -544,18 +556,20 @@ def forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - elif cache_position is not None and cache_position[0] != 0: + elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, @@ -612,8 +626,6 @@ def forward( @auto_docstring class BltPreTrainedModel(MllamaPreTrainedModel): - """Blt PreTrainedModel inheriting from Mllama but with Blt-specific init.""" - config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -676,10 +688,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if input_ids is not None: batch_size, sequence_length = input_ids.shape else: @@ -730,7 +740,7 @@ def forward( kwargs.pop("full_text_row_masked_out_mask", None) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, - input_embeds=encoder_embeds, + inputs_embeds=encoder_embeds, patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, @@ -765,7 +775,7 @@ def forward( ) output = self.local_decoder( input_ids=input_ids, - embeds=encoder_hidden_states, + inputs_embeds=encoder_hidden_states, patch_embeds=global_hidden_states, attention_mask=causal_mask, position_ids=position_ids, @@ -778,7 +788,7 @@ def forward( ) return BaseModelOutputWithPast( last_hidden_state=output, - past_key_values=past_key_values if use_cache else None, + past_key_values=past_key_values, ) def get_input_embeddings(self): @@ -957,11 +967,6 @@ def __init__(self, config: BltConfig): self.post_init() - @can_return_tuple - @auto_docstring - def forward(self, **super_kwargs): - super().forward(**super_kwargs) - __all__ = [ "BltPreTrainedModel", From cb91d0e9eb5a0892418aeb71fb442632422dd35f Mon Sep 17 00:00:00 2001 From: itazap Date: Thu, 7 Aug 2025 09:29:48 +0200 Subject: [PATCH 109/139] apply feedback r2 --- src/demo_hf.py | 39 --- src/transformers/models/auto/modeling_auto.py | 3 +- .../models/blt/configuration_blt.py | 2 +- src/transformers/models/blt/modeling_blt.py | 243 +++++++++--------- src/transformers/models/blt/modular_blt.py | 204 ++++++++------- .../models/blt/tokenization_blt.py | 5 +- 6 files changed, 236 insertions(+), 260 deletions(-) delete mode 100644 src/demo_hf.py diff --git a/src/demo_hf.py b/src/demo_hf.py deleted file mode 100644 index 83f6c3e2d1a8..000000000000 --- a/src/demo_hf.py +++ /dev/null @@ -1,39 +0,0 @@ -import gc -import logging -import os - -import torch - -from transformers import AutoTokenizer -from transformers.models.blt.modeling_blt import BLTForCausalLM - - -logger = logging.getLogger() - -os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" - -gc.collect() -torch.cuda.empty_cache() - - -def main(prompt: str = "my name is", model_name: str = ""): - model = BLTForCausalLM.from_pretrained( - model_name, - device_map="auto", - ) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - - inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - - generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False, use_cache=False) - - output_text = tokenizer.decode(generated_ids[0]) - - print(f'Model: "{model_name}"') - # print(f'Prompt: "{prompt}"') - print(f'Completion: "{output_text}"') - - -if __name__ == "__main__": - main(model_name="itazap/blt-1b-testing") diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 372e6b723c8a..3d0ee2e9fcbd 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -72,8 +72,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blip-2", "Blip2Model"), ("blip_2_qformer", "Blip2QFormerModel"), ("bloom", "BloomModel"), - ("blt", "BltModel"), - ("blt", "BltModel"), + ("blt", "BltModel"), ("bridgetower", "BridgeTowerModel"), ("bros", "BrosModel"), ("camembert", "CamembertModel"), diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 1a3557b7f2bd..f51d7cd18685 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Facebook Research and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 3e71bcd16da1..daa0debe4b8b 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -50,8 +51,6 @@ if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask - -if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask @@ -238,6 +237,13 @@ def eager_attention_forward( return attn_output, attn_weights +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -265,14 +271,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def rotate_half(x): - # From Cohere2. Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x - - class BltSelfAttention(nn.Module): def __init__(self, config: BltConfig, layer_idx: int): super().__init__() @@ -345,6 +343,85 @@ def forward( return attn_output, attn_weights +class BltCrossAttention(nn.Module): + """Cross-attention module for Blt, following transformers style""" + + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(hidden_states) + query_states = self.q_proj(query_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states + return attn_output, attn_weights + + class BltLocalEncoder(nn.Module): def __init__(self, config: BltLocalEncoderConfig): super().__init__() @@ -376,7 +453,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, @@ -417,7 +494,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, - attention_mask=cross_mask, + attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -490,8 +567,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): @@ -520,7 +596,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, - attention_mask=cross_mask, + attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -537,85 +613,6 @@ def forward( return logits -class BltCrossAttention(nn.Module): - """Cross-attention module for Blt, following transformers style""" - - def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): - super().__init__() - self.config = config - self.num_heads = self.config.num_attention_heads - self.num_key_value_heads = self.config.num_key_value_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // self.num_heads - self.layer_idx = layer_idx - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.scaling = self.head_dim**-0.5 - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.is_causal = False - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - cache_position: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(hidden_states) - query_states = self.q_proj(query_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = attn_output + hidden_states - return attn_output, attn_weights - - class BltGlobalTransformer(nn.Module): def __init__(self, config: BltGlobalTransformerConfig): super().__init__() @@ -659,7 +656,11 @@ class BltPreTrainedModel(PreTrainedModel): config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _no_split_modules = [ + "BltVisionEncoderLayer", + "BltCrossAttentionDecoderLayer", + "BltSelfAttentionDecoderLayer", + ] _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True @@ -674,9 +675,6 @@ class BltPreTrainedModel(PreTrainedModel): _supports_static_cache = False # static cache cannot have different shapes for each layer - def _init_weights(self, module): - PreTrainedModel._init_weights(self, module) - def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -946,7 +944,7 @@ def _prepare_patch_cross_attention_mask( ) cross_attention_mask *= full_text_row_masked_out_mask - return cross_attention_mask, full_text_row_masked_out_mask + return cross_attention_mask def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: @@ -1078,8 +1076,17 @@ def forward( ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, encoder_embeds, cache_position, past_key_values) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=encoder_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + cross_attn_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error @@ -1092,16 +1099,22 @@ def forward( position_ids=position_ids, past_key_values=None, cache_position=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + cross_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, - **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) global_position_ids = global_cache_position.unsqueeze(0) - global_causal_mask = self._update_causal_mask(None, global_hidden_states, global_cache_position, None) + global_causal_mask = create_causal_mask( + config=self.config, + input_embeds=global_hidden_states, + attention_mask=None, + cache_position=global_cache_position, + past_key_values=None, + position_ids=None, + ) + global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, @@ -1111,7 +1124,7 @@ def forward( **kwargs, ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) - cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + cross_attn_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, patch_lengths.shape[1], sequence_length, @@ -1127,10 +1140,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, - mask=None, - cross_mask=cross_attn_mask_dec, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, - **kwargs, + cross_attention_mask=cross_attn_mask_dec, ) return BaseModelOutputWithPast( last_hidden_state=output, @@ -1201,11 +1211,14 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) cache_position = torch.arange(sequence_length, device=input_embeds.device) - causal_mask = self._update_causal_mask( - None, # attention_mask - input_embeds, - cache_position, - None, # past_key_values + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=input_embeds, + attention_mask=None, + cache_position=cache_position, + past_key_values=None, + position_ids=None, ) for i, layer in enumerate(self.layers): diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 8b23221454e6..af0c9665f42a 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,7 @@ # limitations under the License. """Blt modular model, inheriting from Mllama where appropriate.""" -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch import torch.distributions @@ -22,9 +22,10 @@ import torch.nn.functional as F from ...cache_utils import Cache +from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging from ...utils.generic import OutputRecorder, check_model_inputs @@ -38,7 +39,7 @@ if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask + pass from ..mllama.modeling_mllama import ( @@ -202,7 +203,7 @@ def _prepare_patch_cross_attention_mask( ) cross_attention_mask *= full_text_row_masked_out_mask - return cross_attention_mask, full_text_row_masked_out_mask + return cross_attention_mask def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: @@ -294,14 +295,6 @@ def __init__(self, config, layer_idx: int): self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps) -def rotate_half(x): - # From Cohere2. Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x - - class BltSelfAttention(MllamaTextSelfAttention): def __init__(self, config: BltConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -330,6 +323,70 @@ def forward( ) +class BltCrossAttention(MllamaTextCrossAttention): + """Cross-attention module for Blt, following transformers style""" + + def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.is_causal = False + self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(hidden_states) + query_states = self.q_proj(query_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output + hidden_states + return attn_output, attn_weights + + class BltLocalEncoder(nn.Module): def __init__(self, config: BltLocalEncoderConfig): super().__init__() @@ -361,7 +418,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, @@ -402,7 +459,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, - attention_mask=cross_mask, + attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -475,8 +532,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, - cross_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): @@ -505,7 +561,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, - attention_mask=cross_mask, + attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -522,70 +578,6 @@ def forward( return logits -class BltCrossAttention(MllamaTextCrossAttention): - """Cross-attention module for Blt, following transformers style""" - - def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None): - super().__init__() - self.is_causal = False - self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) - self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - cache_position: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[TransformersKwargs], - ): - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(hidden_states) - query_states = self.q_proj(query_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - if cross_attention_states is not None: - cross_attention_states = self.k_norm(cross_attention_states) - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - attention_interface: Callable = eager_attention_forward - - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, - scaling=self.scaling, - **kwargs, - ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = attn_output + hidden_states - return attn_output, attn_weights - - class BltGlobalTransformer(nn.Module): def __init__(self, config: BltGlobalTransformerConfig): super().__init__() @@ -629,7 +621,6 @@ class BltPreTrainedModel(MllamaPreTrainedModel): config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True @@ -644,7 +635,7 @@ class BltPreTrainedModel(MllamaPreTrainedModel): } def _init_weights(self, module): - PreTrainedModel._init_weights(self, module) + raise AttributeError("No need to inherit it!") class BltModel(BltPreTrainedModel): @@ -732,8 +723,17 @@ def forward( ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, encoder_embeds, cache_position, past_key_values) - cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=encoder_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + cross_attn_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error @@ -746,16 +746,22 @@ def forward( position_ids=position_ids, past_key_values=None, cache_position=None, - cross_mask=cross_attn_mask_enc, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + cross_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, - **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) global_position_ids = global_cache_position.unsqueeze(0) - global_causal_mask = self._update_causal_mask(None, global_hidden_states, global_cache_position, None) + global_causal_mask = create_causal_mask( + config=self.config, + input_embeds=global_hidden_states, + attention_mask=None, + cache_position=global_cache_position, + past_key_values=None, + position_ids=None, + ) + global_hidden_states = self.global_transformer( input_embeds=global_hidden_states, attention_mask=global_causal_mask, @@ -765,7 +771,7 @@ def forward( **kwargs, ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) - cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + cross_attn_mask_dec = _prepare_patch_cross_attention_mask( decoder_patch_ids, patch_lengths.shape[1], sequence_length, @@ -781,10 +787,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, - mask=None, - cross_mask=cross_attn_mask_dec, - full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, - **kwargs, + cross_attention_mask=cross_attn_mask_dec, ) return BaseModelOutputWithPast( last_hidden_state=output, @@ -855,11 +858,14 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) cache_position = torch.arange(sequence_length, device=input_embeds.device) - causal_mask = self._update_causal_mask( - None, # attention_mask - input_embeds, - cache_position, - None, # past_key_values + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=input_embeds, + attention_mask=None, + cache_position=cache_position, + past_key_values=None, + position_ids=None, ) for i, layer in enumerate(self.layers): diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index e4c0afdbc821..febf73cdb689 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# Copyright 2025 HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,9 +20,6 @@ from ...utils import logging -if TYPE_CHECKING: - pass - logger = logging.get_logger(__name__) # Blt tokenizer constants From f51e2f47c8ed539e0d76d60236a5a20f2ef12681 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 8 Aug 2025 10:55:03 +0000 Subject: [PATCH 110/139] fix cross_attention --- src/transformers/models/blt/modeling_blt.py | 30 ++++++++++----------- src/transformers/models/blt/modular_blt.py | 18 ++++++++----- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index daa0debe4b8b..655f929eebb5 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -238,10 +238,11 @@ def eager_attention_forward( def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -453,7 +454,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, @@ -494,7 +495,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, - attention_mask=cross_attention_mask, + attention_mask=encoder_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -567,7 +568,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): @@ -596,7 +597,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, - attention_mask=cross_attention_mask, + attention_mask=encoder_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -656,11 +657,7 @@ class BltPreTrainedModel(PreTrainedModel): config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = [ - "BltVisionEncoderLayer", - "BltCrossAttentionDecoderLayer", - "BltSelfAttentionDecoderLayer", - ] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True @@ -1099,9 +1096,10 @@ def forward( position_ids=position_ids, past_key_values=None, cache_position=None, - cross_attention_mask=cross_attn_mask_enc, + encoder_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, + **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) @@ -1140,7 +1138,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, - cross_attention_mask=cross_attn_mask_dec, + encoder_attention_mask=cross_attn_mask_dec, + **kwargs, ) return BaseModelOutputWithPast( last_hidden_state=output, @@ -1321,7 +1320,6 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin): _can_compile_fullgraph = False base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config: BltConfig): super().__init__(config.get_text_config()) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index af0c9665f42a..f9250d1bfe58 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -42,6 +42,8 @@ pass +from ..cohere2.modeling_cohere2 import rotate_half + from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, @@ -418,7 +420,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, @@ -459,7 +461,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, - attention_mask=cross_attention_mask, + attention_mask=encoder_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -532,7 +534,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): @@ -561,7 +563,7 @@ def forward( cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, - attention_mask=cross_attention_mask, + attention_mask=encoder_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) @@ -621,6 +623,7 @@ class BltPreTrainedModel(MllamaPreTrainedModel): config: BltConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True @@ -746,9 +749,10 @@ def forward( position_ids=position_ids, past_key_values=None, cache_position=None, - cross_attention_mask=cross_attn_mask_enc, + encoder_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, + **kwargs, ) global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) @@ -787,7 +791,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, - cross_attention_mask=cross_attn_mask_dec, + encoder_attention_mask=cross_attn_mask_dec, + **kwargs, ) return BaseModelOutputWithPast( last_hidden_state=output, @@ -963,7 +968,6 @@ class BltForCausalLM(MllamaForCausalLM): base_model_prefix = "model" _can_compile_fullgraph = False _tied_weights_keys = ["lm_head.weight"] - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config: BltConfig): super().__init__(config) From 22a20f29d0257f200fae98b2bb856c0dce18dbb6 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 8 Aug 2025 16:50:03 +0000 Subject: [PATCH 111/139] apply more feedback --- .../models/auto/tokenization_auto.py | 4 +- src/transformers/models/blt/modeling_blt.py | 40 ++++++------------- src/transformers/models/blt/modular_blt.py | 32 +-------------- .../models/blt/tokenization_blt.py | 2 +- utils/check_repo.py | 2 +- 5 files changed, 18 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e128b927ab8e..785a47e4b30b 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -195,7 +195,6 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), - ("dia", ("DiaTokenizer", None)), ( "deepseek_vl", ( @@ -210,6 +209,7 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("dia", ("DiaTokenizer", None)), ( "diffllama", ( @@ -295,8 +295,8 @@ ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), - ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 655f929eebb5..f514cbe0ba48 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -238,11 +238,10 @@ def eager_attention_forward( def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -374,7 +373,6 @@ def forward( past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -455,7 +453,6 @@ def forward( past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -496,7 +493,6 @@ def forward( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=encoder_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) patch_embeds = patch_embeds + cross_attention_output @@ -569,7 +565,6 @@ def forward( past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): batch_size = inputs_embeds.shape[0] @@ -598,7 +593,6 @@ def forward( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=encoder_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) hidden_states = hidden_states + cross_attention_output @@ -655,9 +649,13 @@ def forward( @auto_docstring class BltPreTrainedModel(PreTrainedModel): config: BltConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _no_split_modules = [ + "BltVisionEncoderLayer", + "BltCrossAttentionDecoderLayer", + "BltSelfAttentionDecoderLayer", + ] _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True @@ -670,8 +668,6 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } - _supports_static_cache = False # static cache cannot have different shapes for each layer - def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], @@ -881,7 +877,6 @@ def _prepare_patch_cross_attention_mask( Returns: Tuple[torch.Tensor, torch.Tensor]: - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ batch_size, seq_len = patch_ids.shape device = patch_ids.device @@ -932,15 +927,6 @@ def _prepare_patch_cross_attention_mask( inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min ) - # Apply full-row bias (following mllama pattern exactly) - # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - return cross_attention_mask @@ -1086,8 +1072,6 @@ def forward( cross_attn_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error - kwargs.pop("full_text_row_masked_out_mask", None) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, inputs_embeds=encoder_embeds, @@ -1221,10 +1205,10 @@ def forward( ) for i, layer in enumerate(self.layers): - layer_outputs = layer( + hidden_states = hidden_states.view(-1, hidden_states.size(-2), hidden_states.size(-1)) + hidden_states = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask ) - hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f9250d1bfe58..ab7962bde1b2 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -42,8 +42,6 @@ pass -from ..cohere2.modeling_cohere2 import rotate_half - from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, @@ -145,7 +143,6 @@ def _prepare_patch_cross_attention_mask( Returns: Tuple[torch.Tensor, torch.Tensor]: - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] - - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows """ batch_size, seq_len = patch_ids.shape device = patch_ids.device @@ -196,15 +193,6 @@ def _prepare_patch_cross_attention_mask( inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min ) - # Apply full-row bias (following mllama pattern exactly) - # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - return cross_attention_mask @@ -341,7 +329,6 @@ def forward( past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): bsz, q_len, _ = hidden_states.size() @@ -421,7 +408,6 @@ def forward( past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -462,7 +448,6 @@ def forward( hidden_states=patch_embeds, cross_attention_states=hidden_states, attention_mask=encoder_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) patch_embeds = patch_embeds + cross_attention_output @@ -535,7 +520,6 @@ def forward( past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ): batch_size = inputs_embeds.shape[0] @@ -564,7 +548,6 @@ def forward( hidden_states=hidden_states, cross_attention_states=patch_embeds, attention_mask=encoder_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, **kwargs, ) hidden_states = hidden_states + cross_attention_output @@ -621,15 +604,6 @@ def forward( @auto_docstring class BltPreTrainedModel(MllamaPreTrainedModel): config: BltConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - - _supports_static_cache = False # static cache cannot have different shapes for each layer - _supports_sdpa = True - _supports_flash_attn = True - _supports_flex_attn = True - _supports_attention_backend = True _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), @@ -739,8 +713,6 @@ def forward( cross_attn_mask_enc = _prepare_patch_cross_attention_mask( patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, encoder_embeds.dtype ) - # Remove full_text_row_masked_out_mask from kwargs if present to avoid multiple values error - kwargs.pop("full_text_row_masked_out_mask", None) encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, inputs_embeds=encoder_embeds, @@ -874,10 +846,10 @@ def forward( ) for i, layer in enumerate(self.layers): - layer_outputs = layer( + hidden_states = hidden_states.view(-1, hidden_states.size(-2), hidden_states.size(-1)) + hidden_states = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask ) - hidden_states = layer_outputs[0] logits = self.lm_head(self.norm(hidden_states)) logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index febf73cdb689..b4de4942f136 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -14,7 +14,7 @@ # limitations under the License. """Tokenization classes for Blt.""" -from typing import TYPE_CHECKING, Optional +from typing import Optional from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging diff --git a/utils/check_repo.py b/utils/check_repo.py index 77a2a0a2e258..7fe5915faa91 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -180,7 +180,7 @@ "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. - "BLTPatcher", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltPatcher", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "Florence2VisionBackbone", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration. ] ) From 39be4145a2e2842aef8898a12111cd1b0fb2aaad Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 14 Aug 2025 21:40:57 +0000 Subject: [PATCH 112/139] update modeling fix --- .../models/blt/convert_blt_weights_to_hf.py | 2 +- src/transformers/models/blt/modeling_blt.py | 139 ++---------------- 2 files changed, 10 insertions(+), 131 deletions(-) diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index a436ca3560cf..7b5fc30d641e 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -72,7 +72,7 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> dict[str "vocab_size": unified_config.get("vocab_size", 256), "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), "cross_attn_k": unified_config.get("cross_attn_k", 2), - "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "hidden_size_global": unified_config.get("dim_global", 2048), "pm_size": unified_config.get("pm_size", 0), "hidden_size": encoder_hidden_size, "num_attention_heads": unified_config.get("n_heads_local_encoder", 16), diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index f514cbe0ba48..25e3bd41df9b 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -238,10 +238,11 @@ def eager_attention_forward( def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -649,13 +650,9 @@ def forward( @auto_docstring class BltPreTrainedModel(PreTrainedModel): config: BltConfig - base_model_prefix = "" + base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = [ - "BltVisionEncoderLayer", - "BltCrossAttentionDecoderLayer", - "BltSelfAttentionDecoderLayer", - ] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True @@ -668,128 +665,9 @@ class BltPreTrainedModel(PreTrainedModel): "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask + _supports_static_cache = False # static cache cannot have different shapes for each layer - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): @@ -1304,6 +1182,7 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin): _can_compile_fullgraph = False base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config: BltConfig): super().__init__(config.get_text_config()) From 6ecc6ff68736cdd05ced32a28c10fe235e010663 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Sun, 17 Aug 2025 16:58:47 +0000 Subject: [PATCH 113/139] load submodules from pretrainedmodel --- src/transformers/models/blt/modeling_blt.py | 72 +++++++++++---------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 25e3bd41df9b..af4f43c06167 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -422,9 +422,31 @@ def forward( return attn_output, attn_weights -class BltLocalEncoder(nn.Module): +@auto_docstring +class BltPreTrainedModel(PreTrainedModel): + config: BltConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + +class BltLocalEncoder(BltPreTrainedModel): + config_class = BltLocalEncoderConfig + base_model_prefix = "local_encoder" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltLocalEncoderConfig): - super().__init__() + super().__init__(config) self.gradient_checkpointing = False self.config = config self.layers = nn.ModuleList( @@ -443,6 +465,7 @@ def __init__(self, config: BltLocalEncoderConfig): self.cross_attn_layers.append( BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() def forward( self, @@ -532,9 +555,13 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): return reduced_embeddings -class BltLocalDecoder(nn.Module): +class BltLocalDecoder(BltPreTrainedModel): + config_class = BltLocalDecoderConfig + base_model_prefix = "local_decoder" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltLocalDecoderConfig): - super().__init__() + super().__init__(config) self.gradient_checkpointing = False self.config = config self.cross_attn_decoder = True @@ -554,6 +581,7 @@ def __init__(self, config: BltLocalDecoderConfig): self.cross_attn_layers.append( BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() @check_model_inputs def forward( @@ -609,14 +637,19 @@ def forward( return logits -class BltGlobalTransformer(nn.Module): +class BltGlobalTransformer(BltPreTrainedModel): + config_class = BltGlobalTransformerConfig + base_model_prefix = "global_transformer" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltGlobalTransformerConfig): - super().__init__() + super().__init__(config) self.config = config self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): self.layers.append(BltTransformerLayer(config, layer_idx)) self.rotary_emb = BltRotaryEmbedding(config=config) + self.post_init() def forward( self, @@ -647,29 +680,6 @@ def forward( return hidden_states -@auto_docstring -class BltPreTrainedModel(PreTrainedModel): - config: BltConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] - _can_compile_fullgraph = False # static cache cannot have different shapes for each layer - _supports_sdpa = True - _supports_flash_attn = True - _supports_flex_attn = True - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), - "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), - } - - _supports_static_cache = False # static cache cannot have different shapes for each layer - - - - def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): primes = [ 1000000007, @@ -856,10 +866,6 @@ class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): super().__init__(config) self.gradient_checkpointing = False - config.patcher_config._attn_implementation = config._attn_implementation - config.encoder_config._attn_implementation = config._attn_implementation - config.decoder_config._attn_implementation = config._attn_implementation - config.global_config._attn_implementation = config._attn_implementation self.config = config self.local_encoder = BltLocalEncoder(config.encoder_config) From eea290d404d94ea967cca9909a0d17e672654a47 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Sun, 17 Aug 2025 19:31:15 +0000 Subject: [PATCH 114/139] set initializer_range to subconfigs --- .../models/blt/configuration_blt.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index f51d7cd18685..2ce9d84e3989 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -46,6 +46,7 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=2816, + initializer_range=0.02, **kwargs, ): self.vocab_size = vocab_size @@ -65,6 +66,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.hidden_act = hidden_act + self.initializer_range = initializer_range super().__init__(**kwargs) @@ -93,6 +95,7 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=2816, + initializer_range=0.02, **kwargs, ): self.vocab_size = vocab_size @@ -111,6 +114,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.hidden_act = hidden_act + self.initializer_range = initializer_range super().__init__(**kwargs) @@ -135,6 +139,7 @@ def __init__( rope_scaling=None, hidden_act="silu", intermediate_size=5632, + initializer_range=0.02, **kwargs, ): self.hidden_size = hidden_size @@ -149,6 +154,7 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.hidden_act = hidden_act + self.initializer_range = initializer_range super().__init__(**kwargs) @@ -202,6 +208,7 @@ def __init__( attn_bias_type="local_block_causal", intermediate_size=2048, rope_scaling=None, + initializer_range=0.02, **kwargs, ): self.vocab_size = vocab_size @@ -218,6 +225,7 @@ def __init__( self.hidden_act = "silu" # Blt uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) self.rope_scaling = rope_scaling + self.initializer_range = initializer_range super().__init__(**kwargs) @@ -350,33 +358,37 @@ def __init__( # Initialize component configurations if patcher_config is None: - self.patcher_config = BltPatcherConfig() + self.patcher_config = BltPatcherConfig(initializer_range=initializer_range) logger.info("patcher_config is None, using default Blt patcher config") elif isinstance(patcher_config, dict): + patcher_config.setdefault("initializer_range", initializer_range) self.patcher_config = BltPatcherConfig(**patcher_config) elif isinstance(patcher_config, BltPatcherConfig): self.patcher_config = patcher_config if encoder_config is None: - self.encoder_config = BltLocalEncoderConfig() + self.encoder_config = BltLocalEncoderConfig(initializer_range=initializer_range) logger.info("encoder_config is None, using default Blt encoder config") elif isinstance(encoder_config, dict): + encoder_config.setdefault("initializer_range", initializer_range) self.encoder_config = BltLocalEncoderConfig(**encoder_config) elif isinstance(encoder_config, BltLocalEncoderConfig): self.encoder_config = encoder_config if decoder_config is None: - self.decoder_config = BltLocalDecoderConfig() + self.decoder_config = BltLocalDecoderConfig(initializer_range=initializer_range) logger.info("decoder_config is None, using default Blt decoder config") elif isinstance(decoder_config, dict): + decoder_config.setdefault("initializer_range", initializer_range) self.decoder_config = BltLocalDecoderConfig(**decoder_config) elif isinstance(decoder_config, BltLocalDecoderConfig): self.decoder_config = decoder_config if global_config is None: - self.global_config = BltGlobalTransformerConfig() + self.global_config = BltGlobalTransformerConfig(initializer_range=initializer_range) logger.info("global_config is None, using default Blt global config") elif isinstance(global_config, dict): + global_config.setdefault("initializer_range", initializer_range) self.global_config = BltGlobalTransformerConfig(**global_config) elif isinstance(global_config, BltGlobalTransformerConfig): self.global_config = global_config @@ -390,4 +402,4 @@ def __init__( "BltLocalEncoderConfig", "BltLocalDecoderConfig", "BltGlobalTransformerConfig", -] +] \ No newline at end of file From 294b80dc7fe65bc708774bfda0d3598998bb294d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Sun, 17 Aug 2025 20:07:15 +0000 Subject: [PATCH 115/139] rm cross_attnetion_states pass when not needed --- src/transformers/models/blt/modeling_blt.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index af4f43c06167..9ef31f1153a0 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -511,8 +511,6 @@ def forward( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) layer_idx = idx if self.config.cross_attn_all_layers else 0 - # Remove cross_attention_states from kwargs if present to avoid multiple values error - kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -616,8 +614,6 @@ def forward( for i, layer in enumerate(self.layers): if i == 0 or self.config.cross_attn_all_layers: - # Remove cross_attention_states from kwargs if present to avoid multiple values error - kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, @@ -1212,7 +1208,6 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, @@ -1224,9 +1219,6 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" - cross_attention_states (`torch.FloatTensor`, *optional*): - Output of the vision model, used for cross-attention. This tensor contains the processed image features that - the language model will attend to. cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): Cross-attention mask to control the interaction between text tokens and image tiles. This 4D tensor defines which image tiles each text token should attend to. @@ -1270,7 +1262,6 @@ def forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, - cross_attention_states=cross_attention_states, attention_mask=attention_mask, position_ids=position_ids, cross_attention_mask=cross_attention_mask, From 9ec7b28fb77b3dfbe5ba50d4545c2ac00c4948d2 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 18 Aug 2025 02:00:45 +0000 Subject: [PATCH 116/139] add 7b projection layer support --- docs/source/en/model_doc/blt.md | 10 - .../models/blt/configuration_blt.py | 8 +- src/transformers/models/blt/modeling_blt.py | 54 ++--- src/transformers/models/blt/modular_blt.py | 185 +++++++++++------- 4 files changed, 151 insertions(+), 106 deletions(-) diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md index 8ab1fcc5dfdd..debabaf9c43a 100644 --- a/docs/source/en/model_doc/blt.md +++ b/docs/source/en/model_doc/blt.md @@ -90,13 +90,3 @@ The original code can be found [here](). [[autodoc]] BLTForTokenClassification - forward - -## FlaxBLTModel - -[[autodoc]] FlaxBLTModel - - __call__ - -## FlaxBLTForCausalLM - -[[autodoc]] FlaxBLTForCausalLM - - __call__ diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 2ce9d84e3989..cd9e63e040ed 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -393,6 +393,12 @@ def __init__( elif isinstance(global_config, BltGlobalTransformerConfig): self.global_config = global_config + # Determine if token embedding projection is needed based on dimension mismatch (7b) + encoder_cross_output_size = self.encoder_config.hidden_size * self.cross_attn_k + self.global_config.encoder_cross_output_size = ( + encoder_cross_output_size if encoder_cross_output_size != self.global_config.hidden_size else None + ) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) @@ -402,4 +408,4 @@ def __init__( "BltLocalEncoderConfig", "BltLocalDecoderConfig", "BltGlobalTransformerConfig", -] \ No newline at end of file +] diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 9ef31f1153a0..b9bda8a58b8e 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -30,14 +30,13 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask -from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_blt import ( BltConfig, @@ -48,12 +47,6 @@ ) -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - class BltMLP(nn.Module): def __init__(self, config): super().__init__() @@ -95,16 +88,16 @@ class BltRotaryEmbedding(nn.Module): def __init__(self, config: BltConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - self.rope_type = ( - config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - if config.rope_scaling is not None - else "default" - ) + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -112,14 +105,13 @@ def __init__(self, config: BltConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - # Copied from Cohere2RotaryEmbedding.forward inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.repeat_interleave(freqs, 2, dim=-1) # Use Cohere2 pattern for compatibility + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -238,11 +230,10 @@ def eager_attention_forward( def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -425,9 +416,10 @@ def forward( @auto_docstring class BltPreTrainedModel(PreTrainedModel): config: BltConfig - base_model_prefix = "model" + base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True @@ -465,6 +457,7 @@ def __init__(self, config: BltLocalEncoderConfig): self.cross_attn_layers.append( BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() def forward( @@ -579,6 +572,7 @@ def __init__(self, config: BltLocalDecoderConfig): self.cross_attn_layers.append( BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() @check_model_inputs @@ -645,6 +639,15 @@ def __init__(self, config: BltGlobalTransformerConfig): for layer_idx in range(config.num_hidden_layers): self.layers.append(BltTransformerLayer(config, layer_idx)) self.rotary_emb = BltRotaryEmbedding(config=config) + + # Create token embedding projection (use nn.Identity() when no projection needed) + if getattr(config, "encoder_cross_output_size", None) is not None: + self.token_embedding_projection = nn.Linear( + config.encoder_cross_output_size, config.hidden_size, bias=False + ) + else: + self.token_embedding_projection = nn.Identity() + self.post_init() def forward( @@ -657,7 +660,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ): batch_size, seq_len, _ = input_embeds.shape - hidden_states = input_embeds + hidden_states = self.token_embedding_projection(input_embeds) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) if position_ids is None: position_ids = ( @@ -1184,7 +1187,6 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin): _can_compile_fullgraph = False base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] def __init__(self, config: BltConfig): super().__init__(config.get_text_config()) @@ -1208,6 +1210,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility cross_attention_mask: Optional[torch.LongTensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, @@ -1219,6 +1222,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): Cross-attention mask to control the interaction between text tokens and image tiles. This 4D tensor defines which image tiles each text token should attend to. @@ -1259,7 +1265,7 @@ def forward( I love the idea of snowflakes gently falling, each one ``` """ - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # Call parent forward but exclude cross_attention_states from model call outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index ab7962bde1b2..0f4e1a7580ac 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -14,7 +14,7 @@ # limitations under the License. """Blt modular model, inheriting from Mllama where appropriate.""" -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch import torch.distributions @@ -23,29 +23,17 @@ from ...cache_utils import Cache from ...masking_utils import create_causal_mask -from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_rope_utils import dynamic_rope_update +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import OutputRecorder, check_model_inputs -from .configuration_blt import ( - BltConfig, - BltGlobalTransformerConfig, - BltLocalDecoderConfig, - BltLocalEncoderConfig, - BltPatcherConfig, +from ..cohere2.modeling_cohere2 import ( + Cohere2RotaryEmbedding, ) - - -if is_torch_flex_attn_available(): - pass - - from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, - MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextMLP, @@ -53,6 +41,13 @@ MllamaTextSelfAttention, eager_attention_forward, ) +from .configuration_blt import ( + BltConfig, + BltGlobalTransformerConfig, + BltLocalDecoderConfig, + BltLocalEncoderConfig, + BltPatcherConfig, +) logger = logging.get_logger(__name__) @@ -248,31 +243,8 @@ class BltRMSNorm(MllamaTextRMSNorm): pass -class BltRotaryEmbedding(MllamaRotaryEmbedding): - def __init__(self, config: BltConfig, device=None): - super().__init__(config=config, device=device) - # BC: "rope_type" was originally "type" - self.rope_type = ( - config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - if config.rope_scaling is not None - else "default" - ) - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # Copied from Cohere2RotaryEmbedding.forward - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.repeat_interleave(freqs, 2, dim=-1) # Use Cohere2 pattern for compatibility - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class BltRotaryEmbedding(Cohere2RotaryEmbedding): + pass class BltTransformerLayer(MllamaSelfAttentionDecoderLayer): @@ -376,9 +348,34 @@ def forward( return attn_output, attn_weights -class BltLocalEncoder(nn.Module): +@auto_docstring +class BltPreTrainedModel(MllamaPreTrainedModel): + config: BltConfig + _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _can_record_outputs = { + "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), + "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } + + def _init_weights(self, module): + raise AttributeError("No need to inherit it!") + + def _update_causal_mask(self, module): + raise AttributeError("No need to inherit it!") + + def _prepare_4d_causal_attention_mask_with_cache_position(self, module): + raise AttributeError("No need to inherit it!") + + +class BltLocalEncoder(BltPreTrainedModel): + config_class = BltLocalEncoderConfig + base_model_prefix = "local_encoder" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltLocalEncoderConfig): - super().__init__() + super().__init__(config) self.gradient_checkpointing = False self.config = config self.layers = nn.ModuleList( @@ -398,6 +395,8 @@ def __init__(self, config: BltLocalEncoderConfig): BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -442,8 +441,6 @@ def forward( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size ) layer_idx = idx if self.config.cross_attn_all_layers else 0 - # Remove cross_attention_states from kwargs if present to avoid multiple values error - kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[layer_idx]( hidden_states=patch_embeds, cross_attention_states=hidden_states, @@ -486,9 +483,13 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): return reduced_embeddings -class BltLocalDecoder(nn.Module): +class BltLocalDecoder(BltPreTrainedModel): + config_class = BltLocalDecoderConfig + base_model_prefix = "local_decoder" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltLocalDecoderConfig): - super().__init__() + super().__init__(config) self.gradient_checkpointing = False self.config = config self.cross_attn_decoder = True @@ -509,6 +510,8 @@ def __init__(self, config: BltLocalDecoderConfig): BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) ) + self.post_init() + @check_model_inputs def forward( self, @@ -542,8 +545,6 @@ def forward( for i, layer in enumerate(self.layers): if i == 0 or self.config.cross_attn_all_layers: - # Remove cross_attention_states from kwargs if present to avoid multiple values error - kwargs.pop("cross_attention_states", None) cross_attention_output, _ = self.cross_attn_layers[i]( hidden_states=hidden_states, cross_attention_states=patch_embeds, @@ -563,15 +564,29 @@ def forward( return logits -class BltGlobalTransformer(nn.Module): +class BltGlobalTransformer(BltPreTrainedModel): + config_class = BltGlobalTransformerConfig + base_model_prefix = "global_transformer" + _no_split_modules = ["BltTransformerLayer"] + def __init__(self, config: BltGlobalTransformerConfig): - super().__init__() + super().__init__(config) self.config = config self.layers = nn.ModuleList() for layer_idx in range(config.num_hidden_layers): self.layers.append(BltTransformerLayer(config, layer_idx)) self.rotary_emb = BltRotaryEmbedding(config=config) + # Create token embedding projection (use nn.Identity() when no projection needed) + if getattr(config, "encoder_cross_output_size", None) is not None: + self.token_embedding_projection = nn.Linear( + config.encoder_cross_output_size, config.hidden_size, bias=False + ) + else: + self.token_embedding_projection = nn.Identity() + + self.post_init() + def forward( self, input_embeds: torch.Tensor, @@ -582,7 +597,7 @@ def forward( **kwargs: Unpack[TransformersKwargs], ): batch_size, seq_len, _ = input_embeds.shape - hidden_states = input_embeds + hidden_states = self.token_embedding_projection(input_embeds) hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) if position_ids is None: position_ids = ( @@ -601,28 +616,10 @@ def forward( return hidden_states -@auto_docstring -class BltPreTrainedModel(MllamaPreTrainedModel): - config: BltConfig - _can_record_outputs = { - "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), - "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), - } - - def _init_weights(self, module): - raise AttributeError("No need to inherit it!") - - class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): super().__init__(config) self.gradient_checkpointing = False - config.patcher_config._attn_implementation = config._attn_implementation - config.encoder_config._attn_implementation = config._attn_implementation - config.decoder_config._attn_implementation = config._attn_implementation - config.global_config._attn_implementation = config._attn_implementation self.config = config self.local_encoder = BltLocalEncoder(config.encoder_config) @@ -949,6 +946,52 @@ def __init__(self, config: BltConfig): self.post_init() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + # Call parent forward but exclude cross_attention_states from model call + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + __all__ = [ "BltPreTrainedModel", From 2f9ab61191036deb692d899e7cf4a48d411bb273 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 18 Aug 2025 07:58:39 -0400 Subject: [PATCH 117/139] check repo --- utils/check_repo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index 7fe5915faa91..eab765f821bf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -181,6 +181,9 @@ "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. "BltPatcher", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltGlobalTransformer", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "Florence2VisionBackbone", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration. ] ) From 3e280825a10e68f15d71e15db2e59d3ee22ca678 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 19 Aug 2025 09:31:34 -0400 Subject: [PATCH 118/139] make copies --- .../models/blt/configuration_blt.py | 100 ++++++------------ .../models/blt/tokenization_blt.py | 27 ++--- utils/check_repo.py | 3 + 3 files changed, 45 insertions(+), 85 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index cd9e63e040ed..735d2cd10d02 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -164,32 +164,19 @@ class BltPatcherConfig(PretrainedConfig): Configuration class for the Blt Patcher/Entropy model component. Args: - vocab_size (`int`, *optional*, defaults to 256): - Vocabulary size for the entropy model used in patching. - hidden_size (`int`, *optional*, defaults to 512): - Hidden dimension for the entropy model. - num_hidden_layers (`int`, *optional*, defaults to 8): - Number of layers in the entropy model. - num_attention_heads (`int`, *optional*, defaults to 8): - Number of attention heads in the entropy model. - head_dim (`int`, *optional*): - Dimension of each attention head in the entropy model. - num_key_value_heads (`int`, *optional*): - Number of key-value heads in the entropy model. - max_position_embeddings (`int`, *optional*, defaults to 1024): - Maximum sequence length for the entropy model. - rms_norm_eps (`float`, *optional*, defaults to 1e-5): - Layer normalization epsilon for the entropy model. - dropout (`float`, *optional*, defaults to 0.0): - Dropout probability for the entropy model. - ffn_dim_multiplier (`float`, *optional*): - Feedforward dimension multiplier for the entropy model. - multiple_of (`int`, *optional*, defaults to 256): - Make feedforward dimension multiple of this for the entropy model. - rope_theta (`float`, *optional*, defaults to 10000.0): - RoPE theta parameter for the entropy model. - attn_bias_type (`str`, *optional*, defaults to "causal"): - Attention bias type for the entropy model. + vocab_size (``, *optional*, defaults to 260): + hidden_size (``, *optional*, defaults to 768): + num_hidden_layers (``, *optional*, defaults to 14): + num_attention_heads (``, *optional*, defaults to 12): + num_key_value_heads (``, *optional*): + max_position_embeddings (``, *optional*, defaults to 8192): + rms_norm_eps (``, *optional*, defaults to 1e-05): + dropout (``, *optional*, defaults to 0.0): + rope_theta (``, *optional*, defaults to 10000.0): + attn_bias_type (``, *optional*, defaults to `"local_block_causal"`): + intermediate_size (``, *optional*, defaults to 2048): + rope_scaling (``, *optional*): + initializer_range (``, *optional*, defaults to 0.02): """ model_type = "blt_patcher" @@ -239,47 +226,26 @@ class BltConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Args: - vocab_size (`int`, *optional*, defaults to 256): - Vocabulary size of the Blt model. Defines the number of different tokens (bytes) that can be represented. - max_position_embeddings (`int`, *optional*, defaults to 1024): - The maximum sequence length that this model can handle. - # Patching configuration - patch_in_forward (`bool`, *optional*, defaults to False): - Whether to perform patching during forward pass. - patch_size (`float`, *optional*): - Size of patches for static patching. - patching_mode (`str`, *optional*): - Mode for patching ("entropy", "static", etc.). - patching_threshold (`float`, *optional*): - Threshold for entropy-based patching. - patching_batch_size (`int`, *optional*, defaults to 1): - Batch size for patching operations. - patching_device (`str`, *optional*, defaults to "cuda"): - Device to use for patching operations. - max_patch_length (`int`, *optional*): - Maximum length of patches. - - # Cross attention configurations - cross_attn_k (`int`, *optional*): - Number of cross attention components. - - # Encoder configurations - encoder_hash_byte_group_size (`Any`, *optional*): - Hash byte group size for encoder. - encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): - Vocabulary size for hash byte groups. - encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): - Number of hash functions for byte groups. - - # Component configurations - patcher_config (`Union[BltPatcherConfig, dict]`, *optional*): - Configuration for the Blt patcher/entropy model component. - encoder_config (`Union[BltLocalEncoderConfig, dict]`, *optional*): - Configuration for the Blt local encoder component. - decoder_config (`Union[BltLocalDecoderConfig, dict]`, *optional*): - Configuration for the Blt local decoder component. - global_config (`Union[BltGlobalTransformerConfig, dict]`, *optional*): - Configuration for the Blt global transformer component. + vocab_size (``, *optional*, defaults to 260): + max_position_embeddings (``, *optional*, defaults to 4096): + patch_in_forward (``, *optional*, defaults to `True`): + patch_size (``, *optional*, defaults to 4): + patching_mode (``, *optional*, defaults to `"entropy"`): + patching_threshold (``, *optional*, defaults to 1.34): + patching_batch_size (``, *optional*, defaults to 1): + max_patch_length (``, *optional*): + cross_attn_k (``, *optional*, defaults to 2): + encoder_hash_byte_group_size (``, *optional*): + encoder_hash_byte_group_vocab (``, *optional*, defaults to 500002): + encoder_hash_byte_group_nb_functions (``, *optional*, defaults to 1): + patcher_config (``, *optional*): + encoder_config (``, *optional*): + decoder_config (``, *optional*): + global_config (``, *optional*): + tie_word_embeddings (``, *optional*, defaults to `False`): + initializer_range (``, *optional*, defaults to 0.02): + rope_theta (``, *optional*, defaults to 500000.0): + rope_scaling (``, *optional*): ```python >>> from transformers import BltModel, BltConfig diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index b4de4942f136..fd8b857474bb 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -44,24 +44,15 @@ class BltTokenizer(PreTrainedTokenizer): beginning of example (BOE), and padding (PAD). Args: - bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of sequence token. - eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The end of sequence token. - pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The padding token. - unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. Not used in Blt but kept for compatibility. - boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of example token, specific to Blt. - add_bos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add a `bos_token` at the start of sequences. - add_eos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add an `eos_token` at the end of sequences. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to cleanup spaces after decoding. - spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to add spaces between special tokens. + bos_token (``, *optional*, defaults to `""`): + eos_token (``, *optional*, defaults to `""`): + pad_token (``, *optional*, defaults to `""`): + unk_token (``, *optional*, defaults to `""`): + boe_token (``, *optional*, defaults to `""`): + add_bos_token (``, *optional*, defaults to `True`): + add_eos_token (``, *optional*, defaults to `False`): + clean_up_tokenization_spaces (``, *optional*, defaults to `False`): + spaces_between_special_tokens (``, *optional*, defaults to `False`): """ vocab_files_names = VOCAB_FILES_NAMES diff --git a/utils/check_repo.py b/utils/check_repo.py index eab765f821bf..3c4445a32f9a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -99,6 +99,9 @@ "Glm4vVisionModel", "Glm4vMoeVisionModel", "EvollaSaProtPreTrainedModel", + "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. + "BltGlobalTransformer", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "Ovis2VisionModel", ] From 52fa98715b73e51e8b5d0a6d58dde45b41251be9 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 19 Aug 2025 14:37:50 +0000 Subject: [PATCH 119/139] lost cohere2 rotate_half --- src/transformers/models/blt/modeling_blt.py | 9 +++++---- src/transformers/models/blt/modular_blt.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index b9bda8a58b8e..fa6439cd3270 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -230,10 +230,11 @@ def eager_attention_forward( def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 0f4e1a7580ac..ea5bb895a1f0 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -30,6 +30,7 @@ from ...utils.generic import OutputRecorder, check_model_inputs from ..cohere2.modeling_cohere2 import ( Cohere2RotaryEmbedding, + rotate_half ) from ..mllama.modeling_mllama import ( MllamaForCausalLM, From f25630ca2460b0bd12625f4dfed26874b6e91bff Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 19 Aug 2025 14:43:23 +0000 Subject: [PATCH 120/139] ruff --- .../models/blt/configuration_blt.py | 2 +- src/transformers/models/blt/modeling_blt.py | 41 +++++++++++-------- src/transformers/models/blt/modular_blt.py | 23 +++++------ 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 735d2cd10d02..8d037d091329 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -310,7 +310,7 @@ def __init__( self.max_patch_length = max_patch_length self.patching_device = kwargs.get("patching_device", "cuda") self.realtime_patching = kwargs.get("realtime_patching", True) - self.patching_threshold_add = kwargs.get("patching_threshold_add", None) + self.patching_threshold_add = kwargs.get("patching_threshold_add") self.monotonicity = kwargs.get("monotonicity", False) self.pm_size = kwargs.get("pm_size", 0) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index fa6439cd3270..fa76e177ad8c 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -37,6 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_blt import ( BltConfig, @@ -85,6 +86,8 @@ def extra_repr(self): class BltRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: BltConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" @@ -131,6 +134,7 @@ def __init__(self, config, layer_idx: int): self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -139,7 +143,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -155,7 +159,7 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -174,7 +178,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -230,11 +234,10 @@ def eager_attention_forward( def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -284,13 +287,14 @@ def __init__(self, config: BltConfig, layer_idx: int): self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.is_causal = True + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, use_cache: bool = False, - past_key_value=None, + past_key_values=None, cache_position=None, **kwargs, ): @@ -309,10 +313,10 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -359,11 +363,12 @@ def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.is_causal = False + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -380,14 +385,14 @@ def forward( value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update( + if past_key_values is not None: + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, ) else: raise ValueError( @@ -494,7 +499,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -620,7 +625,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index ea5bb895a1f0..86084bf1cdea 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -28,10 +28,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import OutputRecorder, check_model_inputs -from ..cohere2.modeling_cohere2 import ( - Cohere2RotaryEmbedding, - rotate_half -) +from ..cohere2.modeling_cohere2 import Cohere2RotaryEmbedding from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, @@ -269,7 +266,7 @@ def forward( attention_mask: torch.Tensor, position_embeddings: torch.Tensor, use_cache: bool = False, - past_key_value=None, + past_key_values=None, cache_position=None, **kwargs, ): @@ -280,7 +277,7 @@ def forward( attention_mask=attention_mask, position_embeddings=position_embeddings, use_cache=use_cache, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -299,7 +296,7 @@ def forward( self, hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -315,14 +312,14 @@ def forward( value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update( + if past_key_values is not None: + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, ) else: raise ValueError( @@ -431,7 +428,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -557,7 +554,7 @@ def forward( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) From 26706e59d7c37be84bf8bdd1922d42fd67ca36c6 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 19 Aug 2025 15:00:10 +0000 Subject: [PATCH 121/139] copies? --- docs/source/en/model_doc/blt.md | 40 +++++---------------- src/transformers/models/blt/modeling_blt.py | 9 ++--- src/transformers/models/blt/modular_blt.py | 5 ++- utils/check_repo.py | 4 +++ 4 files changed, 21 insertions(+), 37 deletions(-) diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md index debabaf9c43a..cfab8d3a2422 100644 --- a/docs/source/en/model_doc/blt.md +++ b/docs/source/en/model_doc/blt.md @@ -45,48 +45,24 @@ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface The original code can be found [here](). -## BLTConfig +## BltConfig -[[autodoc]] BLTConfig +[[autodoc]] BltConfig -## BLTTokenizer +## BltTokenizer -[[autodoc]] BLTTokenizer +[[autodoc]] BltTokenizer - build_inputs_with_special_tokens - get_special_tokens_mask - create_token_type_ids_from_sequences - save_vocabulary -## BLTTokenizerFast +## BltModel -[[autodoc]] BLTTokenizerFast - - build_inputs_with_special_tokens - - get_special_tokens_mask - - create_token_type_ids_from_sequences - - update_post_processor - - save_vocabulary - -## BLTModel - -[[autodoc]] BLTModel - - forward - -## BLTForCausalLM - -[[autodoc]] BLTForCausalLM - - forward - -## BLTForSequenceClassification - -[[autodoc]] BLTForSequenceClassification - - forward - -## BLTForQuestionAnswering - -[[autodoc]] BLTForQuestionAnswering +[[autodoc]] BltModel - forward -## BLTForTokenClassification +## BltForCausalLM -[[autodoc]] BLTForTokenClassification +[[autodoc]] BltForCausalLM - forward diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index fa76e177ad8c..26ee6aa2f071 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -234,10 +234,11 @@ def eager_attention_forward( def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 86084bf1cdea..dbf6c7159ab4 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -28,7 +28,10 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import OutputRecorder, check_model_inputs -from ..cohere2.modeling_cohere2 import Cohere2RotaryEmbedding +from ..cohere2.modeling_cohere2 import ( + Cohere2RotaryEmbedding, + rotate_half +) from ..mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, diff --git a/utils/check_repo.py b/utils/check_repo.py index 3c4445a32f9a..ffd3fb56d773 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -409,6 +409,7 @@ "CsmDepthDecoderModel", # Building part of a bigger model "CsmDepthDecoderForCausalLM", # Building part of a bigger model "CsmForConditionalGeneration", # Building part of a bigger model + "BltPatcher", # Building part of a bigger model, tested implicitly through BltForCausalLM "Florence2VisionBackbone", # Building part of a bigger model ] @@ -1155,6 +1156,9 @@ def ignore_undocumented(name: str) -> bool: # MMBT model does not really work. if name.startswith("MMBT"): return True + # BLT models are internal building blocks, tested implicitly through BltForCausalLM + if name.startswith("Blt"): + return True if name in SHOULD_HAVE_THEIR_OWN_PAGE: return True return False From 35dde6ef6fee6cce0feb19e7313f413978ea1d41 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 19 Aug 2025 16:34:31 +0000 Subject: [PATCH 122/139] don't tie weights for submodules --- src/transformers/models/blt/configuration_blt.py | 5 ++++- src/transformers/models/blt/modeling_blt.py | 6 ++++++ src/transformers/models/blt/modular_blt.py | 8 +++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 8d037d091329..b3212fa8193c 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -258,7 +258,10 @@ class BltConfig(PretrainedConfig): >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + + Checkpoint: [facebook/blt](https://huggingface.co/facebook/blt) + """ model_type = "blt" keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 26ee6aa2f071..9630811f412d 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1302,5 +1302,11 @@ def forward( attentions=outputs.attentions, ) + def tie_weights(self): + """Prevent double execution from from_pretrained().""" + if hasattr(self, "_weights_tied"): + return + self._weights_tied = True + __all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index dbf6c7159ab4..6fd8d3c827ef 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -30,7 +30,7 @@ from ...utils.generic import OutputRecorder, check_model_inputs from ..cohere2.modeling_cohere2 import ( Cohere2RotaryEmbedding, - rotate_half + rotate_half, # noqa: F401 ) from ..mllama.modeling_mllama import ( MllamaForCausalLM, @@ -947,6 +947,12 @@ def __init__(self, config: BltConfig): self.post_init() + def tie_weights(self): + """Prevent double execution from from_pretrained().""" + if hasattr(self, '_weights_tied'): + return + self._weights_tied = True + def forward( self, input_ids: Optional[torch.LongTensor] = None, From f855e52f23942127d17566005007a69098d92644 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 20 Aug 2025 15:42:00 +0000 Subject: [PATCH 123/139] tie weights setting --- .../models/blt/configuration_blt.py | 21 ++++---- src/transformers/models/blt/modeling_blt.py | 16 +++---- src/transformers/models/blt/modular_blt.py | 16 +++---- tests/models/blt/test_modeling_blt.py | 48 ++++--------------- 4 files changed, 31 insertions(+), 70 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index b3212fa8193c..1bf6c0635a11 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -34,7 +34,6 @@ def __init__( cross_attn_all_layers=False, cross_attn_k=2, hidden_size_global=2048, - pm_size=0, hidden_size=1024, num_attention_heads=16, num_key_value_heads=None, @@ -47,13 +46,13 @@ def __init__( hidden_act="silu", intermediate_size=2816, initializer_range=0.02, + tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size self.cross_attn_all_layers = cross_attn_all_layers self.cross_attn_k = cross_attn_k self.hidden_size_global = hidden_size_global - self.pm_size = pm_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads or num_attention_heads @@ -68,7 +67,7 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class BltLocalDecoderConfig(PretrainedConfig): @@ -96,6 +95,7 @@ def __init__( hidden_act="silu", intermediate_size=2816, initializer_range=0.02, + tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -116,7 +116,7 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class BltGlobalTransformerConfig(PretrainedConfig): @@ -140,6 +140,7 @@ def __init__( hidden_act="silu", intermediate_size=5632, initializer_range=0.02, + tie_word_embeddings=False, **kwargs, ): self.hidden_size = hidden_size @@ -156,7 +157,7 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class BltPatcherConfig(PretrainedConfig): @@ -173,7 +174,6 @@ class BltPatcherConfig(PretrainedConfig): rms_norm_eps (``, *optional*, defaults to 1e-05): dropout (``, *optional*, defaults to 0.0): rope_theta (``, *optional*, defaults to 10000.0): - attn_bias_type (``, *optional*, defaults to `"local_block_causal"`): intermediate_size (``, *optional*, defaults to 2048): rope_scaling (``, *optional*): initializer_range (``, *optional*, defaults to 0.02): @@ -192,10 +192,10 @@ def __init__( rms_norm_eps=1e-5, dropout=0.0, rope_theta=10000.0, - attn_bias_type="local_block_causal", intermediate_size=2048, rope_scaling=None, initializer_range=0.02, + tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -208,13 +208,12 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.dropout = dropout self.rope_theta = rope_theta - self.attn_bias_type = attn_bias_type self.hidden_act = "silu" # Blt uses silu activation self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) self.rope_scaling = rope_scaling self.initializer_range = initializer_range - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class BltConfig(PretrainedConfig): @@ -290,14 +289,13 @@ def __init__( encoder_config=None, decoder_config=None, global_config=None, - tie_word_embeddings=False, + tie_word_embeddings=True, initializer_range=0.02, rope_theta=500000.0, rope_scaling=None, **kwargs, ): # Basic model configuration - self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range @@ -315,7 +313,6 @@ def __init__( self.realtime_patching = kwargs.get("realtime_patching", True) self.patching_threshold_add = kwargs.get("patching_threshold_add") self.monotonicity = kwargs.get("monotonicity", False) - self.pm_size = kwargs.get("pm_size", 0) # Cross attention configurations self.cross_attn_k = cross_attn_k diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 9630811f412d..60cb9943a390 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -430,7 +430,7 @@ class BltPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True - _supports_attention_backend = True + _supports_attention_backend = False _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), @@ -440,7 +440,7 @@ class BltPreTrainedModel(PreTrainedModel): class BltLocalEncoder(BltPreTrainedModel): - config_class = BltLocalEncoderConfig + config: BltLocalEncoderConfig base_model_prefix = "local_encoder" _no_split_modules = ["BltTransformerLayer"] @@ -554,7 +554,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BltLocalDecoder(BltPreTrainedModel): - config_class = BltLocalDecoderConfig + config: BltLocalDecoderConfig base_model_prefix = "local_decoder" _no_split_modules = ["BltTransformerLayer"] @@ -635,7 +635,7 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): - config_class = BltGlobalTransformerConfig + config: BltGlobalTransformerConfig base_model_prefix = "global_transformer" _no_split_modules = ["BltTransformerLayer"] @@ -1040,6 +1040,8 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + def __init__(self, config: BltPatcherConfig): super().__init__(config) self.rotary_emb = BltRotaryEmbedding(config=self.config) @@ -1302,11 +1304,5 @@ def forward( attentions=outputs.attentions, ) - def tie_weights(self): - """Prevent double execution from from_pretrained().""" - if hasattr(self, "_weights_tied"): - return - self._weights_tied = True - __all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"] diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 6fd8d3c827ef..70fc06745051 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -352,6 +352,7 @@ def forward( @auto_docstring class BltPreTrainedModel(MllamaPreTrainedModel): config: BltConfig + _supports_attention_backend = False _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), @@ -371,7 +372,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(self, module): class BltLocalEncoder(BltPreTrainedModel): - config_class = BltLocalEncoderConfig + config: BltLocalEncoderConfig base_model_prefix = "local_encoder" _no_split_modules = ["BltTransformerLayer"] @@ -485,7 +486,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BltLocalDecoder(BltPreTrainedModel): - config_class = BltLocalDecoderConfig + config: BltLocalDecoderConfig base_model_prefix = "local_decoder" _no_split_modules = ["BltTransformerLayer"] @@ -566,7 +567,7 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): - config_class = BltGlobalTransformerConfig + config: BltGlobalTransformerConfig base_model_prefix = "global_transformer" _no_split_modules = ["BltTransformerLayer"] @@ -789,6 +790,8 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + def __init__(self, config: BltPatcherConfig): super().__init__(config) self.rotary_emb = BltRotaryEmbedding(config=self.config) @@ -944,15 +947,8 @@ def __init__(self, config: BltConfig): self.vocab_size = config.vocab_size self.model = BltModel(config) self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) - self.post_init() - def tie_weights(self): - """Prevent double execution from from_pretrained().""" - if hasattr(self, '_weights_tied'): - return - self._weights_tied = True - def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index dcbf89d6168b..76ffcc8e6da9 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -322,34 +322,6 @@ def test_eager_padding_matches_padding_free_with_position_ids(): def test_sdpa_padding_matches_padding_free_with_position_ids(): return - # def test_attention_outputs(self): - # if not self.has_attentions: - # self.skipTest(reason="Model does not output attentions") - - # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # config.return_dict = True - # config._attn_implementation = "eager" - - # for model_class in self.all_model_classes: - # inputs_dict["output_attentions"] = True - # inputs_dict["output_hidden_states"] = False - # config.return_dict = True - # model = model_class._from_config(config, attn_implementation="eager") - # config = model.config - # model.to(torch_device) - # model.eval() - # with torch.no_grad(): - # outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - # # For Blt, check separate attention outputs from each component - # attentions = outputs.attentions - # encoder_attentions = outputs.encoder_attentions - # global_attentions = outputs.global_attentions - - # # Each component should have attention outputs equal to their layer count - # self.assertEqual(len(attentions), config.decoder_config.num_hidden_layers) - # self.assertEqual(len(encoder_attentions), config.encoder_config.num_hidden_layers) - # self.assertEqual(len(global_attentions), config.global_config.num_hidden_layers) - @require_torch_accelerator class BltIntegrationTest(unittest.TestCase): @@ -368,9 +340,9 @@ def test_model(self): prompt = "my name is" - model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa") + model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa") - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) @@ -455,7 +427,7 @@ def test_model_logits(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] - model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", attn_implementation="sdpa", device_map="auto") + model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", attn_implementation="sdpa", device_map="auto") with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] @@ -473,10 +445,10 @@ def test_model_bf16(self): prompt = "my name is" model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 ) - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) @@ -565,7 +537,7 @@ def test_model_logits_bf16(self): input_ids = [1, 42, 21, 12, 43, 23, 1, 4] model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 ) with torch.no_grad(): @@ -584,9 +556,9 @@ def test_model_eager(self): prompt = "my name is" - model = BltForCausalLM.from_pretrained("itazap/blt-1b-testing", device_map="auto", attn_implementation="eager") + model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="eager") - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) @@ -608,12 +580,12 @@ def test_model_bf16_static_cache(self): prompt = "my name is" model = BltForCausalLM.from_pretrained( - "itazap/blt-1b-testing", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 + "itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16 ) model.generation_config.cache_implementation = "static" - tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-testing") + tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf") inputs = tokenizer(prompt, return_tensors="pt").to(model.device) From 966e2f032d2569362ba22bba9077d185eadc9eca Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Wed, 20 Aug 2025 18:33:48 +0000 Subject: [PATCH 124/139] check docstrings --- .../models/auto/tokenization_auto.py | 1 + .../models/blt/configuration_blt.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 785a47e4b30b..3b00eeed8114 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -297,6 +297,7 @@ ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 1bf6c0635a11..90cb0f4b883f 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -46,7 +46,6 @@ def __init__( hidden_act="silu", intermediate_size=2816, initializer_range=0.02, - tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -67,7 +66,9 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) class BltLocalDecoderConfig(PretrainedConfig): @@ -95,7 +96,6 @@ def __init__( hidden_act="silu", intermediate_size=2816, initializer_range=0.02, - tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -116,7 +116,9 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) class BltGlobalTransformerConfig(PretrainedConfig): @@ -140,7 +142,6 @@ def __init__( hidden_act="silu", intermediate_size=5632, initializer_range=0.02, - tie_word_embeddings=False, **kwargs, ): self.hidden_size = hidden_size @@ -157,7 +158,9 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) class BltPatcherConfig(PretrainedConfig): @@ -195,7 +198,6 @@ def __init__( intermediate_size=2048, rope_scaling=None, initializer_range=0.02, - tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -213,7 +215,9 @@ def __init__( self.rope_scaling = rope_scaling self.initializer_range = initializer_range - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) + super().__init__(**kwargs, tie_word_embeddings=False) class BltConfig(PretrainedConfig): @@ -289,7 +293,7 @@ def __init__( encoder_config=None, decoder_config=None, global_config=None, - tie_word_embeddings=True, + tie_word_embeddings=False, initializer_range=0.02, rope_theta=500000.0, rope_scaling=None, @@ -365,6 +369,8 @@ def __init__( encoder_cross_output_size if encoder_cross_output_size != self.global_config.hidden_size else None ) + # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error + kwargs.pop("tie_word_embeddings", None) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) From 5513a6a6046a4b7430ad37dedea4bac234f3dd7d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 22 Aug 2025 13:47:59 +0000 Subject: [PATCH 125/139] apply feedback --- .../models/blt/configuration_blt.py | 104 ++++++++++++------ src/transformers/models/blt/modeling_blt.py | 6 - src/transformers/models/blt/modular_blt.py | 6 - .../models/blt/tokenization_blt.py | 30 +++-- tests/models/blt/test_modeling_blt.py | 103 ++++++----------- 5 files changed, 128 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py index 90cb0f4b883f..0bc6718e5bd1 100644 --- a/src/transformers/models/blt/configuration_blt.py +++ b/src/transformers/models/blt/configuration_blt.py @@ -168,18 +168,37 @@ class BltPatcherConfig(PretrainedConfig): Configuration class for the Blt Patcher/Entropy model component. Args: - vocab_size (``, *optional*, defaults to 260): - hidden_size (``, *optional*, defaults to 768): - num_hidden_layers (``, *optional*, defaults to 14): - num_attention_heads (``, *optional*, defaults to 12): - num_key_value_heads (``, *optional*): - max_position_embeddings (``, *optional*, defaults to 8192): - rms_norm_eps (``, *optional*, defaults to 1e-05): - dropout (``, *optional*, defaults to 0.0): - rope_theta (``, *optional*, defaults to 10000.0): - intermediate_size (``, *optional*, defaults to 2048): - rope_scaling (``, *optional*): - initializer_range (``, *optional*, defaults to 0.02): + vocab_size (`int`, *optional*, defaults to 260): + Vocabulary size of the Blt patcher model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling the patcher model. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 14): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + rope_scaling (`dict`, *optional*): + Dictionary containing the RoPE scaling configuration. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. """ model_type = "blt_patcher" @@ -229,26 +248,47 @@ class BltConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Args: - vocab_size (``, *optional*, defaults to 260): - max_position_embeddings (``, *optional*, defaults to 4096): - patch_in_forward (``, *optional*, defaults to `True`): - patch_size (``, *optional*, defaults to 4): - patching_mode (``, *optional*, defaults to `"entropy"`): - patching_threshold (``, *optional*, defaults to 1.34): - patching_batch_size (``, *optional*, defaults to 1): - max_patch_length (``, *optional*): - cross_attn_k (``, *optional*, defaults to 2): - encoder_hash_byte_group_size (``, *optional*): - encoder_hash_byte_group_vocab (``, *optional*, defaults to 500002): - encoder_hash_byte_group_nb_functions (``, *optional*, defaults to 1): - patcher_config (``, *optional*): - encoder_config (``, *optional*): - decoder_config (``, *optional*): - global_config (``, *optional*): - tie_word_embeddings (``, *optional*, defaults to `False`): - initializer_range (``, *optional*, defaults to 0.02): - rope_theta (``, *optional*, defaults to 500000.0): - rope_scaling (``, *optional*): + vocab_size (`int`, *optional*, defaults to 260): + Vocabulary size of the Blt model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BltModel`]. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + patch_in_forward (`bool`, *optional*, defaults to `True`): + Whether to perform patching during the forward pass. + patch_size (`int`, *optional*, defaults to 4): + Size of the patches used in the patching mechanism. + patching_mode (`str`, *optional*, defaults to `"entropy"`): + The mode used for patching, such as entropy-based patching. + patching_threshold (`float`, *optional*, defaults to 1.34): + Threshold value used for determining when to apply patches. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size used during the patching process. + max_patch_length (`int`, *optional*): + Maximum length of patches that can be generated. + cross_attn_k (`int`, *optional*, defaults to 2): + Number of cross-attention heads used in the model. + encoder_hash_byte_group_size (`list`, *optional*): + List of byte group sizes used in the encoder hash function. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 500002): + Vocabulary size for the encoder hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 1): + Number of hash functions used in the encoder byte grouping. + patcher_config (`BltPatcherConfig`, *optional*): + Configuration for the patcher component of the model. + encoder_config (`BltLocalEncoderConfig`, *optional*): + Configuration for the local encoder component of the model. + decoder_config (`BltLocalDecoderConfig`, *optional*): + Configuration for the local decoder component of the model. + global_config (`BltGlobalTransformerConfig`, *optional*): + Configuration for the global transformer component of the model. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rope_theta (`float`, *optional*, defaults to 500000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the RoPE scaling configuration. ```python >>> from transformers import BltModel, BltConfig diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 60cb9943a390..d1a0cce1d1be 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -441,8 +441,6 @@ class BltPreTrainedModel(PreTrainedModel): class BltLocalEncoder(BltPreTrainedModel): config: BltLocalEncoderConfig - base_model_prefix = "local_encoder" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltLocalEncoderConfig): super().__init__(config) @@ -555,8 +553,6 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BltLocalDecoder(BltPreTrainedModel): config: BltLocalDecoderConfig - base_model_prefix = "local_decoder" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltLocalDecoderConfig): super().__init__(config) @@ -636,8 +632,6 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): config: BltGlobalTransformerConfig - base_model_prefix = "global_transformer" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltGlobalTransformerConfig): super().__init__(config) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 70fc06745051..7a05b1327192 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -373,8 +373,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(self, module): class BltLocalEncoder(BltPreTrainedModel): config: BltLocalEncoderConfig - base_model_prefix = "local_encoder" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltLocalEncoderConfig): super().__init__(config) @@ -487,8 +485,6 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): class BltLocalDecoder(BltPreTrainedModel): config: BltLocalDecoderConfig - base_model_prefix = "local_decoder" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltLocalDecoderConfig): super().__init__(config) @@ -568,8 +564,6 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): config: BltGlobalTransformerConfig - base_model_prefix = "global_transformer" - _no_split_modules = ["BltTransformerLayer"] def __init__(self, config: BltGlobalTransformerConfig): super().__init__(config) diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index fd8b857474bb..70ad64362f64 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -44,15 +44,27 @@ class BltTokenizer(PreTrainedTokenizer): beginning of example (BOE), and padding (PAD). Args: - bos_token (``, *optional*, defaults to `""`): - eos_token (``, *optional*, defaults to `""`): - pad_token (``, *optional*, defaults to `""`): - unk_token (``, *optional*, defaults to `""`): - boe_token (``, *optional*, defaults to `""`): - add_bos_token (``, *optional*, defaults to `True`): - add_eos_token (``, *optional*, defaults to `False`): - clean_up_tokenization_spaces (``, *optional*, defaults to `False`): - spaces_between_special_tokens (``, *optional*, defaults to `False`): + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of example token used for marking the start of individual examples in a sequence. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. """ vocab_files_names = VOCAB_FILES_NAMES diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 76ffcc8e6da9..d6923a7c9b0b 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -17,6 +17,7 @@ import pytest from parameterized import parameterized +from torch import return_types from transformers import AutoTokenizer, is_torch_available, set_seed from transformers.testing_utils import ( @@ -204,18 +205,18 @@ class BltModelTest(CausalLMModelTest, unittest.TestCase): @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) + @unittest.skip( + "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + ) def test_generate_from_inputs_embeds(self, _, num_beams): - """Skip this test for Blt as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" - self.skipTest( - "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" - ) + pass @pytest.mark.generate + @unittest.skip( + "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" + ) def test_inputs_embeds_matches_input_ids(self): - """Skip this test for Blt as it has complex embedding computation that requires real token IDs for hash-based embeddings.""" - self.skipTest( - "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" - ) + pass @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @@ -247,14 +248,6 @@ def test_eager_matches_sdpa_inference( self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols ) - def test_torchscript_simple(self): - """Skip torchscript test for Blt as it has complex patching logic that's not compatible.""" - self.skipTest("Blt has complex patching logic that's not compatible with torchscript") - - def test_torchscript_output_hidden_state(self): - """Skip torchscript test for Blt as it has complex patching logic that's not compatible.""" - self.skipTest("Blt has complex patching logic that's not compatible with torchscript") - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): """Override rope scaling from config test to handle Blt's sub-config structure.""" @@ -294,33 +287,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - @unittest.skip(reason="Training is not supported yet") - def test_training_gradient_checkpointing(self): - pass - - @unittest.skip( - reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - ) - def test_training_gradient_checkpointing_use_reentrant(self): - pass - - @unittest.skip( - reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - ) - def test_training_gradient_checkpointing_use_reentrant_false(self): - pass - @unittest.skip(reason="Decoder cannot keep gradients") def test_flex_attention_with_grads(): - return - - @unittest.skip(reason="Padding with patcher is complex") - def test_eager_padding_matches_padding_free_with_position_ids(): - return - - @unittest.skip(reason="Padding with patcher is complex") - def test_sdpa_padding_matches_padding_free_with_position_ids(): - return + return_types @require_torch_accelerator @@ -440,7 +409,7 @@ def test_model_logits(self): def test_model_bf16(self): """Test Blt model with bfloat16 precision.""" NUM_TOKENS_TO_GENERATE = 200 - EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" prompt = "my name is" @@ -464,74 +433,74 @@ def test_model_bf16(self): @require_torch_bf16 def test_model_logits_bf16(self): """Test Blt model logits with bfloat16 precision.""" + EXPECTED_OUTPUT = torch.tensor( [ [ -10.5000, -10.6875, -6.1875, - -10.5000, + -10.5625, -10.3125, - -9.1250, + -9.1875, -8.5000, - -8.6250, + -8.6875, -9.1875, -9.5625, -9.3750, -8.5000, -9.0625, - -3.4062, - 2.9688, + -3.4219, + 2.9531, -10.3125, -6.4062, - -5.9688, + -6.0000, -9.6875, -9.1875, -8.8125, -9.8125, -9.7500, -9.4375, - -9.7500, - -9.4375, + -9.8125, + -9.5000, -9.0000, - -9.7500, + -9.8125, -9.4375, -9.3125, ], [ - -13.3125, + -13.2500, -13.1875, -5.6875, - -13.2500, + -13.3125, -13.5000, -8.7500, + -7.0625, -7.0312, - -7.0000, -10.1250, -10.3750, -9.8750, - -7.7812, + -7.8438, -8.8750, - -5.2500, - -3.5312, - -12.5625, + -5.2812, + -3.5625, + -12.5000, -9.1875, - -6.7812, + -6.8125, -10.3750, - -9.2500, + -9.3125, -10.6250, -11.5000, - -11.1875, - -10.9375, + -11.2500, + -11.0000, -10.5625, -10.8750, -11.0625, -11.3750, - -10.5000, + -10.5625, -10.0000, ], - ], - dtype=torch.bfloat16, + ] ).to(torch_device) input_ids = [1, 42, 21, 12, 43, 23, 1, 4] @@ -543,8 +512,6 @@ def test_model_logits_bf16(self): with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] - # print(output[0, :2, :30]) - torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) @slow @@ -575,7 +542,7 @@ def test_model_eager(self): def test_model_bf16_static_cache(self): """Test Blt model with bfloat16 precision and static cache.""" NUM_TOKENS_TO_GENERATE = 200 - EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" prompt = "my name is" From 29144c79412ac8c942fab4f6eeac16d3e36d3088 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 22 Aug 2025 14:03:48 +0000 Subject: [PATCH 126/139] rebase --- tests/models/blt/test_modeling_blt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index d6923a7c9b0b..97870bf374cb 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -26,7 +26,6 @@ require_torch, require_torch_accelerator, require_torch_bf16, - require_torch_sdpa, slow, torch_device, ) @@ -219,7 +218,6 @@ def test_inputs_embeds_matches_input_ids(self): pass @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) - @require_torch_sdpa def test_eager_matches_sdpa_inference( self, name, From 8869cc1917e3f9af618de13fc7b8462a4712274f Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 25 Aug 2025 09:54:50 -0400 Subject: [PATCH 127/139] rebased modeling --- src/transformers/models/blt/modeling_blt.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index d1a0cce1d1be..8b6e24797fa0 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1200,12 +1200,6 @@ def __init__(self, config: BltConfig): self.post_init() - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - @can_return_tuple @auto_docstring def forward( From f3e62f00d621d71a07f75897558bfd70c3882644 Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 27 Aug 2025 10:37:22 -0400 Subject: [PATCH 128/139] update docs --- docs/source/en/model_doc/blt.md | 59 +++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md index cfab8d3a2422..ebc54039d181 100644 --- a/docs/source/en/model_doc/blt.md +++ b/docs/source/en/model_doc/blt.md @@ -24,25 +24,64 @@ rendered properly in your Markdown viewer. -# BLT - -# BLT +# Byte Lantet Transformer (BLT) ## Overview -The BLT model was proposed in []() by . - +The BLT model was proposed in [Byte Latent Transformer: Patches Scale Better Than Tokens]() by Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li1, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer. +BLT is a byte-level LLM that achieves tokenization-level performance through entropy-based dynamic patching. The abstract from the paper is the following: -** +*We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that, for the first time, matches tokenization-based LLM performance at scale with significant improvements in inference +efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve as the primary units of computation. Patches are segmented based on the entropy of the next byte, allocating +more compute and model capacity where increased data complexity demands it. We present the first flop controlled scaling study of byte-level models up to 8B parameters and 4T training bytes. Our results demonstrate the feasibility of scaling models trained on raw bytes without a fixed vocabulary. Both training and inference efficiency improve due to dynamically selecting long patches when data is predictable, along with qualitative improvements on reasoning and long tail generalization. Overall, for fixed inference costs, BLT shows significantly better scaling than tokenization-based models, by simultaneously growing both patch and model size.* + +## Usage Tips: + +- **Dual Model Architecture**: BLT consists of two separate trained models: + - **Patcher (Entropy Model)**: A smaller transformer model that predicts byte-level entropy to determine patch boundaries and segment input. + - **Main Transformer Model**: The primary model that processes the patches through a Local Encoder, Global Transformer, and Local Decoder. + +- **Dynamic Patching**: The model uses entropy-based dynamic patching where: + - High-entropy regions (complex data) get shorter patches with more computational attention + - Low-entropy regions (predictable data) get longer patches for efficiency + - This allows the model to allocate compute resources where they're most needed + +- **Local Encoder**: Processes byte sequences with cross-attention to patch embeddings +- **Global Transformer**: Processes patch-level representations with full attention across patches +- **Local Decoder**: Generates output with cross-attention back to the original byte sequence + +- **Byte-Level Tokenizer**: Unlike traditional tokenizers that use learned vocabularies, BLT's tokenizer simply converts text to UTF-8 bytes and maps each byte to a token ID. There is no need for a vocabulary. + +The model can be loaded via: + + + +```python +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + +tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf") +model = AutoModelForCausalLM.from_pretrained( + "itazap/blt-1b-hf", + device_map="auto", +) + +inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + +prompt = "my name is" +generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False +) -Tips: +print(tokenizer.decode(generated_ids[0])) +``` - + -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +This model was contributed by [itazap](https://huggingface.co/). +The original code can be found [here](). ## BltConfig From cab52b59186e716dd16fa1c0cc70b77bb7e59e60 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 28 Aug 2025 18:12:59 +0000 Subject: [PATCH 129/139] applying feedback --- src/transformers/models/blt/modeling_blt.py | 496 +++++++++--------- src/transformers/models/blt/modular_blt.py | 412 ++++++++------- .../models/blt/tokenization_blt.py | 69 +-- tests/models/blt/test_modeling_blt.py | 3 +- tests/models/blt/test_tokenization_blt.py | 12 +- 5 files changed, 500 insertions(+), 492 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 8b6e24797fa0..b3136417985d 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -27,7 +27,7 @@ import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -425,7 +425,7 @@ class BltPreTrainedModel(PreTrainedModel): config: BltConfig base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _no_split_modules = ["BltTransformerLayer"] _can_compile_fullgraph = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_flash_attn = True @@ -503,7 +503,7 @@ def forward( **kwargs, ) if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size @@ -519,7 +519,7 @@ def forward( encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + def patch_reduce(self, hidden_states, max_num_patches, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch @@ -543,7 +543,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): src=hidden_states, dim=1, index=patch_ids, - reduce=reduction, + reduce="amax", include_self=False, ) reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] @@ -680,28 +680,225 @@ def forward( return hidden_states -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): - primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + + def __init__(self, config: BltPatcherConfig): + super().__init__(config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) + self.layers = nn.ModuleList() + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BltTransformerLayer(self.config, layer_idx)) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + + logits = self.lm_head(self.norm(hidden_states)) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + + batch_size, sequence_length = inputs_embeds.shape[:2] + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=prediction_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return prediction_entropies, patch_lengths, logits + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + if threshold is None: + # Use top-k entropy values to define patch start points + num_patches = sequence_length // patch_size + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values + else: + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + +def rolling_polynomial_hash(token_tensor, prime: int = 1000000007): + """ + A polynomial rolling hash algorithm that converts sequences + of tokens into hash values. The hash is computed as: + hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n) + + The rolling hash allows the model to efficiently + identify and encode recurring byte-level patterns in the input text. + + Args: + token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash + prime (int): Prime number used as the base for the polynomial hash. + + Returns: + torch.Tensor: Hash values of shape [batch_size, seq_len] where each value + represents the hash of the corresponding token group + + Example: + >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> hashes = rolling_polynomial_hash(tokens, prime=31) + >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2 + >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2 + """ + prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device) powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime**powers + prime_powers = prime_tensor**powers return torch.sum(token_tensor * prime_powers, dim=-1) def byte_group_hash_function( - token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 + token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000 ): """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): @@ -712,7 +909,7 @@ def byte_group_hash_function( # Create sliding windows and compute hashes windows = padded_tokens.unfold(1, group_size, 1) - hashes = rolling_polynomial_hash(windows, hash_func_nb) + hashes = rolling_polynomial_hash(windows, prime) hash_values = hashes % max_hash return hash_values @@ -727,13 +924,27 @@ def compute_hash_embeddings( encoder_hash_byte_group_vocab: int, ) -> torch.Tensor: """Compute token embeddings enhanced with hash-based embeddings.""" + # Available primes for hash functions + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + embeddings = local_encoder.embed_tokens(local_encoder_tokens) embedding_idx = 0 for func_nb in range(encoder_hash_byte_group_nb_functions): + prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes for group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab - ) + hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) embedding_idx += 1 @@ -818,50 +1029,6 @@ def _prepare_patch_cross_attention_mask( return cross_attention_mask -def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: - """ - Splits patch lengths into smaller segments if they exceed `max_patch_length`. - Pads the result to uniform length across the batch. - - Args: - patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. - max_patch_length (int, optional): Maximum allowed length per patch. - - Returns: - torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. - """ - if max_patch_length is None: - return patch_lengths - - batch_size = patch_lengths.size(0) - processed = [] - - for seq in patch_lengths: - splits = [] - for length in seq[seq > 0]: - length = length.item() - full_chunks, remainder = divmod(length, max_patch_length) - splits.extend([max_patch_length] * full_chunks) - if remainder: - splits.append(remainder) - processed.append(splits) - - # Find max length to pad to - max_len = max(len(splits) for splits in processed) - padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) - - for i, splits in enumerate(processed): - if splits: - padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) - - # Trim zero columns - if (padded != 0).any(dim=0).sum() < padded.shape[1]: - last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 - padded = padded[:, :last_nonzero] - - return padded - - class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): super().__init__(config) @@ -901,10 +1068,22 @@ def forward( ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if input_ids is not None: - batch_size, sequence_length = input_ids.shape - else: + + # Extract input embeddings as early as possible + if inputs_embeds is not None: + encoder_embeds = inputs_embeds batch_size, sequence_length, _ = inputs_embeds.shape + else: + batch_size, sequence_length = input_ids.shape + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + if patch_lengths is None: if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: @@ -925,17 +1104,6 @@ def forward( self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - if inputs_embeds is not None: - encoder_embeds = inputs_embeds - else: - encoder_embeds = compute_hash_embeddings( - input_ids, - self.local_encoder, - self.encoder_hash_tok_embedding, - self.config.encoder_hash_byte_group_nb_functions, - self.config.encoder_hash_byte_group_size, - self.config.encoder_hash_byte_group_vocab, - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -959,22 +1127,19 @@ def forward( encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, inputs_embeds=encoder_embeds, - patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, - past_key_values=None, - cache_position=None, encoder_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, **kwargs, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) + encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device) global_position_ids = global_cache_position.unsqueeze(0) global_causal_mask = create_causal_mask( config=self.config, - input_embeds=global_hidden_states, + input_embeds=encoder_cross_states, attention_mask=None, cache_position=global_cache_position, past_key_values=None, @@ -982,11 +1147,9 @@ def forward( ) global_hidden_states = self.global_transformer( - input_embeds=global_hidden_states, + input_embeds=encoder_cross_states, attention_mask=global_causal_mask, position_ids=global_position_ids, - past_key_values=None, - cache_position=None, **kwargs, ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) @@ -1033,153 +1196,6 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 -class BltPatcher(BltPreTrainedModel): - config: BltPatcherConfig - - def __init__(self, config: BltPatcherConfig): - super().__init__(config) - self.rotary_emb = BltRotaryEmbedding(config=self.config) - self.layers = nn.ModuleList() - for layer_idx in range(self.config.num_hidden_layers): - self.layers.append(BltTransformerLayer(self.config, layer_idx)) - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.lm_head = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=False, - ) - - def forward( - self, - token_values: torch.Tensor, - patch_size: Optional[int] = None, - threshold: Optional[float] = None, - max_patch_length: Optional[int] = None, - patching_batch_size: int = 1, - device: Optional[str] = None, - **kwargs: Unpack[TransformersKwargs], - ): - entropies = [] - predictions = [] - max_length = self.config.max_position_embeddings - batch_numel = max_length * patching_batch_size - splits = torch.split(token_values.flatten(), batch_numel) - for split in splits: - pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) - split = torch.cat((split, pad), dim=0) - split = split.reshape(-1, max_length) - if device is not None: - split = split.to(device) - batch_size, sequence_length = split.shape - input_embeds = self.embed_tokens(split) - hidden_states = input_embeds - batch_size = input_embeds.shape[0] - position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - cache_position = torch.arange(sequence_length, device=input_embeds.device) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=input_embeds, - attention_mask=None, - cache_position=cache_position, - past_key_values=None, - position_ids=None, - ) - - for i, layer in enumerate(self.layers): - hidden_states = hidden_states.view(-1, hidden_states.size(-2), hidden_states.size(-1)) - hidden_states = layer( - hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask - ) - - logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] - predictions.append(logits) - prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() - entropies.append(prediction_entropies) - concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - batch_size, sequence_length = token_values.shape - if patch_size is not None: - patch_lengths = self.patch_lengths_from_entropies( - entropies=concat_entropies, - sequence_length=sequence_length, - patch_size=patch_size, - threshold=threshold, - ) - else: - patch_lengths = torch.ones( - (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device - ) - patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_predictions - - @staticmethod - def patch_lengths_from_entropies( - entropies, - sequence_length, - patch_size=None, - threshold=None, - ): - """ - Computes patch lengths from token entropies. - - Depending on whether a threshold is provided, the function uses either: - - Top-k selection based on entropy (when `threshold` is None), or - - Thresholding the entropy values (when `threshold` is set). - """ - - batch_size = entropies.shape[0] - - # Always include token 0 and 1 as starting tokens - init_tokens = ( - torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - ) - offset = init_tokens.shape[1] - - # Ignore first token entropy (BOS) - entropies = entropies[:, 1:] - - if threshold is None: - # Use top-k entropy values to define patch start points - num_patches = sequence_length // patch_size - topk_indices = entropies.topk(num_patches - 2, dim=1).indices - patch_starts = topk_indices.sort(dim=1).values - else: - # Threshold the entropy values to define patch start points - patch_mask = entropies > threshold - - seq_len = patch_mask.shape[1] - - # Create patch IDs (token indices), and add a sentinel to ensure alignment - token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) - sentinel = torch.full_like(token_indices, seq_len) - padded_indices = torch.cat([token_indices, sentinel], dim=1) - - # Pad mask with inverse to align sentinel correctly - padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) - - # Select indices where mask is True - patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) - max_valid_patches = patch_mask.sum(dim=1).max() - patch_starts = patch_starts[:, :max_valid_patches] - - # Offset patch starts to account for the two initial tokens - patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) - - # Compute patch end positions by shifting start positions - last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) - patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) - - patch_lengths = patch_ends - patch_start_ids + 1 - - return patch_lengths - - @auto_docstring( custom_intro=""" The Blt Text Model with a language modeling head on top. diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 7a05b1327192..0a1201b8cca1 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -21,7 +21,7 @@ import torch.nn as nn import torch.nn.functional as F -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -54,28 +54,37 @@ logger = logging.get_logger(__name__) -def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): - primes = [ - 1000000007, - 5915587277, - 1500450271, - 3267000013, - 5754853343, - 4093082899, - 9576890767, - 3628273133, - 2860486313, - 5463458053, - 3367900313, - ] - prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) +def rolling_polynomial_hash(token_tensor, prime: int = 1000000007): + """ + A polynomial rolling hash algorithm that converts sequences + of tokens into hash values. The hash is computed as: + hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n) + + The rolling hash allows the model to efficiently + identify and encode recurring byte-level patterns in the input text. + + Args: + token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash + prime (int): Prime number used as the base for the polynomial hash. + + Returns: + torch.Tensor: Hash values of shape [batch_size, seq_len] where each value + represents the hash of the corresponding token group + + Example: + >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> hashes = rolling_polynomial_hash(tokens, prime=31) + >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2 + >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2 + """ + prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device) powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) - prime_powers = prime**powers + prime_powers = prime_tensor**powers return torch.sum(token_tensor * prime_powers, dim=-1) def byte_group_hash_function( - token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 + token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000 ): """Hash token groups and map to range [0, max_hash].""" with torch.no_grad(): @@ -86,7 +95,7 @@ def byte_group_hash_function( # Create sliding windows and compute hashes windows = padded_tokens.unfold(1, group_size, 1) - hashes = rolling_polynomial_hash(windows, hash_func_nb) + hashes = rolling_polynomial_hash(windows, prime) hash_values = hashes % max_hash return hash_values @@ -101,13 +110,27 @@ def compute_hash_embeddings( encoder_hash_byte_group_vocab: int, ) -> torch.Tensor: """Compute token embeddings enhanced with hash-based embeddings.""" + # Available primes for hash functions + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + embeddings = local_encoder.embed_tokens(local_encoder_tokens) embedding_idx = 0 for func_nb in range(encoder_hash_byte_group_nb_functions): + prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes for group_size in encoder_hash_byte_group_size: - hash_ids = byte_group_hash_function( - local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab - ) + hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) embedding_idx += 1 @@ -273,8 +296,6 @@ def forward( cache_position=None, **kwargs, ): - if hidden_states.dim() == 2: - hidden_states = hidden_states.unsqueeze(0) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, @@ -353,7 +374,7 @@ def forward( class BltPreTrainedModel(MllamaPreTrainedModel): config: BltConfig _supports_attention_backend = False - _no_split_modules = ["BltTransformerLayer", "BltLocalEncoder", "BltLocalDecoder", "BltGlobalTransformer"] + _no_split_modules = ["BltTransformerLayer"] _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), @@ -435,7 +456,7 @@ def forward( **kwargs, ) if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: - patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids) patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size @@ -451,7 +472,7 @@ def forward( encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states - def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + def patch_reduce(self, hidden_states, max_num_patches, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch @@ -475,7 +496,7 @@ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): src=hidden_states, dim=1, index=patch_ids, - reduce=reduction, + reduce="amax", include_self=False, ) reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] @@ -612,6 +633,150 @@ def forward( return hidden_states +class BltPatcher(BltPreTrainedModel): + config: BltPatcherConfig + + def __init__(self, config: BltPatcherConfig): + super().__init__(config) + self.rotary_emb = BltRotaryEmbedding(config=self.config) + self.layers = nn.ModuleList() + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BltTransformerLayer(self.config, layer_idx)) + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + **kwargs: Unpack[TransformersKwargs], + ): + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask) + + logits = self.lm_head(self.norm(hidden_states)) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + + batch_size, sequence_length = inputs_embeds.shape[:2] + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=prediction_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + patch_lengths = torch.ones( + (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return prediction_entropies, patch_lengths, logits + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + ) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + if threshold is None: + # Use top-k entropy values to define patch start points + num_patches = sequence_length // patch_size + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values + else: + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + class BltModel(BltPreTrainedModel): def __init__(self, config: BltConfig): super().__init__(config) @@ -651,10 +816,22 @@ def forward( ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if input_ids is not None: - batch_size, sequence_length = input_ids.shape - else: + + # Extract input embeddings as early as possible + if inputs_embeds is not None: + encoder_embeds = inputs_embeds batch_size, sequence_length, _ = inputs_embeds.shape + else: + batch_size, sequence_length = input_ids.shape + encoder_embeds = compute_hash_embeddings( + input_ids, + self.local_encoder, + self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + if patch_lengths is None: if self.config.patching_mode == "entropy" and self.patcher is not None: if input_ids is None: @@ -675,17 +852,6 @@ def forward( self.config.max_patch_length, ) patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) - if inputs_embeds is not None: - encoder_embeds = inputs_embeds - else: - encoder_embeds = compute_hash_embeddings( - input_ids, - self.local_encoder, - self.encoder_hash_tok_embedding, - self.config.encoder_hash_byte_group_nb_functions, - self.config.encoder_hash_byte_group_size, - self.config.encoder_hash_byte_group_vocab, - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -709,22 +875,19 @@ def forward( encoder_hidden_states, encoder_cross_states = self.local_encoder( input_ids=input_ids, inputs_embeds=encoder_embeds, - patch_embeds=None, attention_mask=causal_mask, position_ids=position_ids, - past_key_values=None, - cache_position=None, encoder_attention_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, **kwargs, ) - global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) - global_cache_position = torch.arange(0, global_hidden_states.shape[1], device=global_hidden_states.device) + encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device) global_position_ids = global_cache_position.unsqueeze(0) global_causal_mask = create_causal_mask( config=self.config, - input_embeds=global_hidden_states, + input_embeds=encoder_cross_states, attention_mask=None, cache_position=global_cache_position, past_key_values=None, @@ -732,11 +895,9 @@ def forward( ) global_hidden_states = self.global_transformer( - input_embeds=global_hidden_states, + input_embeds=encoder_cross_states, attention_mask=global_causal_mask, position_ids=global_position_ids, - past_key_values=None, - cache_position=None, **kwargs, ) decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) @@ -783,157 +944,10 @@ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 -class BltPatcher(BltPreTrainedModel): - config: BltPatcherConfig - - def __init__(self, config: BltPatcherConfig): - super().__init__(config) - self.rotary_emb = BltRotaryEmbedding(config=self.config) - self.layers = nn.ModuleList() - for layer_idx in range(self.config.num_hidden_layers): - self.layers.append(BltTransformerLayer(self.config, layer_idx)) - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - self.lm_head = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=False, - ) - - def forward( - self, - token_values: torch.Tensor, - patch_size: Optional[int] = None, - threshold: Optional[float] = None, - max_patch_length: Optional[int] = None, - patching_batch_size: int = 1, - device: Optional[str] = None, - **kwargs: Unpack[TransformersKwargs], - ): - entropies = [] - predictions = [] - max_length = self.config.max_position_embeddings - batch_numel = max_length * patching_batch_size - splits = torch.split(token_values.flatten(), batch_numel) - for split in splits: - pad_size = (max_length - (split.numel() % max_length)) % max_length - pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) - split = torch.cat((split, pad), dim=0) - split = split.reshape(-1, max_length) - if device is not None: - split = split.to(device) - batch_size, sequence_length = split.shape - input_embeds = self.embed_tokens(split) - hidden_states = input_embeds - batch_size = input_embeds.shape[0] - position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - cache_position = torch.arange(sequence_length, device=input_embeds.device) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=input_embeds, - attention_mask=None, - cache_position=cache_position, - past_key_values=None, - position_ids=None, - ) - - for i, layer in enumerate(self.layers): - hidden_states = hidden_states.view(-1, hidden_states.size(-2), hidden_states.size(-1)) - hidden_states = layer( - hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask - ) - - logits = self.lm_head(self.norm(hidden_states)) - logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] - predictions.append(logits) - prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() - entropies.append(prediction_entropies) - concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) - concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) - batch_size, sequence_length = token_values.shape - if patch_size is not None: - patch_lengths = self.patch_lengths_from_entropies( - entropies=concat_entropies, - sequence_length=sequence_length, - patch_size=patch_size, - threshold=threshold, - ) - else: - patch_lengths = torch.ones( - (batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device - ) - patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) - return concat_entropies, patch_lengths, concat_predictions - - @staticmethod - def patch_lengths_from_entropies( - entropies, - sequence_length, - patch_size=None, - threshold=None, - ): - """ - Computes patch lengths from token entropies. - - Depending on whether a threshold is provided, the function uses either: - - Top-k selection based on entropy (when `threshold` is None), or - - Thresholding the entropy values (when `threshold` is set). - """ - - batch_size = entropies.shape[0] - - # Always include token 0 and 1 as starting tokens - init_tokens = ( - torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) - ) - offset = init_tokens.shape[1] - - # Ignore first token entropy (BOS) - entropies = entropies[:, 1:] - - if threshold is None: - # Use top-k entropy values to define patch start points - num_patches = sequence_length // patch_size - topk_indices = entropies.topk(num_patches - 2, dim=1).indices - patch_starts = topk_indices.sort(dim=1).values - else: - # Threshold the entropy values to define patch start points - patch_mask = entropies > threshold - - seq_len = patch_mask.shape[1] - - # Create patch IDs (token indices), and add a sentinel to ensure alignment - token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) - sentinel = torch.full_like(token_indices, seq_len) - padded_indices = torch.cat([token_indices, sentinel], dim=1) - - # Pad mask with inverse to align sentinel correctly - padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) - - # Select indices where mask is True - patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) - max_valid_patches = patch_mask.sum(dim=1).max() - patch_starts = patch_starts[:, :max_valid_patches] - - # Offset patch starts to account for the two initial tokens - patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) - - # Compute patch end positions by shifting start positions - last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) - patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) - - patch_lengths = patch_ends - patch_start_ids + 1 - - return patch_lengths - - class BltForCausalLM(MllamaForCausalLM): config: BltConfig - base_model_prefix = "model" _can_compile_fullgraph = False + base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BltConfig): diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 70ad64362f64..9e3fd2015e4e 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -16,22 +16,12 @@ from typing import Optional -from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging logger = logging.get_logger(__name__) -# Blt tokenizer constants -SEP = " " -BOS_ID: int = 1 -EOS_ID: int = 2 -PAD_ID: int = 260 # Use valid ID after byte tokens (4-259) -BOE_ID: int = 0 -BPE_ID: int = 3 -OFFSET: int = 4 -BYTE_UNITS: int = 256 - VOCAB_FILES_NAMES = {} # Blt doesn't require external vocab files @@ -56,6 +46,8 @@ class BltTokenizer(PreTrainedTokenizer): token instead. boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): The beginning of example token used for marking the start of individual examples in a sequence. + additional_special_tokens (`list[str]` or `list[tokenizers.AddedToken]`, *optional*): + A list of additional special tokens to be added to the tokenizer's vocabulary. add_bos_token (`bool`, *optional*, defaults to `True`): Whether or not to add an `bos_token` at the start of sequences. add_eos_token (`bool`, *optional*, defaults to `False`): @@ -77,6 +69,7 @@ def __init__( pad_token="", unk_token="", boe_token="", + additional_special_tokens=None, add_bos_token=True, add_eos_token=False, clean_up_tokenization_spaces=False, @@ -86,30 +79,22 @@ def __init__( # Store Blt-specific parameters first self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token - self.vocab_size_unit_1 = BYTE_UNITS - self.offsetting_special_char = OFFSET - - # Blt token IDs (exactly like original) - self.boe_id = BOE_ID - self.bos_id = BOS_ID - self.eos_id = EOS_ID - self.pad_id = PAD_ID - self.bpe_id = BPE_ID - self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char - self.boe_token = boe_token - - # Build encoder (token -> id) and decoder (id -> token) mappings - self.encoder = {} + self.byte_vocab_size = 256 # byte units (0-255) + + boe_token = AddedToken(boe_token, lstrip=False, rstrip=False) if isinstance(boe_token, str) else boe_token + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + self._added_tokens_decoder = {0: boe_token, 1: bos_token, 2: eos_token, 3: pad_token} + self.offset = len(self._added_tokens_decoder) + self._utf_vocab_size = 2**8 # utf is 8 bits - # Add special tokens to encoder - self.encoder[str(bos_token)] = self.bos_id - self.encoder[str(eos_token)] = self.eos_id - self.encoder[str(pad_token)] = self.pad_id - self.encoder[str(boe_token)] = self.boe_id + self.encoder = {} - # Add byte tokens (0-255) to encoder - for i in range(self.vocab_size_unit_1): - self.encoder[str(i)] = i + self.offsetting_special_char + # Add byte tokens (0-255) to encoder with offset + for i in range(self.byte_vocab_size): + self.encoder[str(i)] = i + self.offset # Create decoder as reverse of encoder self.decoder = {v: k for k, v in self.encoder.items()} @@ -120,6 +105,7 @@ def __init__( pad_token=pad_token, unk_token=unk_token, boe_token=boe_token, + additional_special_tokens=additional_special_tokens, add_bos_token=add_bos_token, add_eos_token=add_eos_token, clean_up_tokenization_spaces=clean_up_tokenization_spaces, @@ -130,8 +116,8 @@ def __init__( @property def vocab_size(self): """Returns vocab size""" - # Account for byte tokens (4-259) plus special tokens (0,1,2,3,260) - return max(self.vocab_size_unit_1 + self.offsetting_special_char, PAD_ID + 1) + # Account for byte tokens plus special tokens + return self._utf_vocab_size + self.offset def get_vocab(self): """Returns vocab as a dict""" @@ -155,13 +141,8 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: byte_values = [] for token in tokens: - # Skip special tokens by checking if they're in encoder but not byte tokens - if token in self.encoder and token in { - str(self.bos_token), - str(self.eos_token), - str(self.pad_token), - str(self.boe_token), - }: + # Skip special tokens by checking if they're in added_tokens_encoder + if token in self.added_tokens_encoder: continue try: @@ -196,8 +177,8 @@ def build_inputs_with_special_tokens( Returns: `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. """ - bos = [self.bos_id] if self.add_bos_token else [] - eos = [self.eos_id] if self.add_eos_token else [] + bos = [self.bos_token_id] if self.add_bos_token else [] + eos = [self.eos_token_id] if self.add_eos_token else [] if token_ids_1 is None: return bos + token_ids_0 + eos diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 97870bf374cb..9336d1d0b1c2 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -17,7 +17,6 @@ import pytest from parameterized import parameterized -from torch import return_types from transformers import AutoTokenizer, is_torch_available, set_seed from transformers.testing_utils import ( @@ -287,7 +286,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): @unittest.skip(reason="Decoder cannot keep gradients") def test_flex_attention_with_grads(): - return_types + pass @require_torch_accelerator diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py index d2077104ac2c..e2a71bf1163a 100644 --- a/tests/models/blt/test_tokenization_blt.py +++ b/tests/models/blt/test_tokenization_blt.py @@ -47,19 +47,17 @@ def test_blt_tokenizer_basics(self): """Test basic Blt tokenizer functionality""" tokenizer = BltTokenizer() - # Test vocab size (256 bytes + 4 offset + special tokens) - self.assertEqual(tokenizer.vocab_size, 261) + # Test vocab size (256 bytes + 4 special tokens) + self.assertEqual(tokenizer.vocab_size, 260) # Test special token IDs - self.assertEqual(tokenizer.bos_id, 1) - self.assertEqual(tokenizer.eos_id, 2) - self.assertEqual(tokenizer.boe_id, 0) - self.assertEqual(tokenizer.pad_id, 260) + self.assertEqual(tokenizer.bos_token_id, 1) + self.assertEqual(tokenizer.eos_token_id, 2) + self.assertEqual(tokenizer.pad_token_id, 3) # Test special tokens self.assertEqual(str(tokenizer.bos_token), "") self.assertEqual(str(tokenizer.eos_token), "") - self.assertEqual(str(tokenizer.boe_token), "") self.assertEqual(str(tokenizer.pad_token), "") def test_simple_encode_decode(self): From d45f2603697eb26b9291ee8f538c843ae662e426 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 29 Aug 2025 20:55:45 +0000 Subject: [PATCH 130/139] few more fixes --- src/transformers/models/blt/modeling_blt.py | 2 - .../models/blt/tokenization_blt.py | 37 +++++++++++-------- tests/models/blt/test_tokenization_blt.py | 36 ------------------ 3 files changed, 22 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index b3136417985d..dd86aec87bbe 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -299,8 +299,6 @@ def forward( cache_position=None, **kwargs, ): - if hidden_states.dim() == 2: - hidden_states = hidden_states.unsqueeze(0) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 9e3fd2015e4e..5a92f2b57a33 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -80,25 +80,16 @@ def __init__( self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.byte_vocab_size = 256 # byte units (0-255) - + boe_token = AddedToken(boe_token, lstrip=False, rstrip=False) if isinstance(boe_token, str) else boe_token bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token - + self._added_tokens_decoder = {0: boe_token, 1: bos_token, 2: eos_token, 3: pad_token} self.offset = len(self._added_tokens_decoder) self._utf_vocab_size = 2**8 # utf is 8 bits - self.encoder = {} - - # Add byte tokens (0-255) to encoder with offset - for i in range(self.byte_vocab_size): - self.encoder[str(i)] = i + self.offset - - # Create decoder as reverse of encoder - self.decoder = {v: k for k, v in self.encoder.items()} - super().__init__( bos_token=bos_token, eos_token=eos_token, @@ -121,20 +112,36 @@ def vocab_size(self): def get_vocab(self): """Returns vocab as a dict""" - return dict(self.encoder, **self.added_tokens_encoder) + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab def _convert_token_to_id(self, token: str) -> int: """Converts a token (str) to an id using the vocab.""" - return self.encoder.get(token, self.added_tokens_encoder.get(token, self.unk_token_id)) + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + + # For byte tokens, convert string representation to integer and add offset + try: + byte_val = int(token) + if 0 <= byte_val <= 255: + return byte_val + self.offset + except ValueError: + pass + + return self.unk_token_id def _convert_id_to_token(self, index: int) -> str: """Converts an index (integer) to a token (str) using the vocab.""" - # Check added tokens first (they might override special token IDs) for token, token_id in self.added_tokens_encoder.items(): if token_id == index: return token - return self.decoder.get(index, str(self.unk_token)) + # For byte tokens, subtract offset and convert to string + if self.offset <= index < self.vocab_size: + return str(index - self.offset) + + return str(self.unk_token) def convert_tokens_to_string(self, tokens: list[str]) -> str: """Converts a sequence of tokens to a single string.""" diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py index e2a71bf1163a..b14cd3ab64a8 100644 --- a/tests/models/blt/test_tokenization_blt.py +++ b/tests/models/blt/test_tokenization_blt.py @@ -43,23 +43,6 @@ def get_tokenizers(self, **kwargs): kwargs.update({"add_bos_token": True, "add_eos_token": False}) return super().get_tokenizers(**kwargs) - def test_blt_tokenizer_basics(self): - """Test basic Blt tokenizer functionality""" - tokenizer = BltTokenizer() - - # Test vocab size (256 bytes + 4 special tokens) - self.assertEqual(tokenizer.vocab_size, 260) - - # Test special token IDs - self.assertEqual(tokenizer.bos_token_id, 1) - self.assertEqual(tokenizer.eos_token_id, 2) - self.assertEqual(tokenizer.pad_token_id, 3) - - # Test special tokens - self.assertEqual(str(tokenizer.bos_token), "") - self.assertEqual(str(tokenizer.eos_token), "") - self.assertEqual(str(tokenizer.pad_token), "") - def test_simple_encode_decode(self): tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) @@ -155,25 +138,6 @@ def test_empty_and_whitespace(self): decoded = tokenizer.decode(encoded) self.assertEqual(decoded, " ") - def test_get_vocab(self): - tokenizer = BltTokenizer() - vocab = tokenizer.get_vocab() - - # Should contain special tokens - self.assertIn(str(tokenizer.bos_token), vocab) - self.assertIn(str(tokenizer.eos_token), vocab) - self.assertIn(str(tokenizer.boe_token), vocab) - self.assertIn(str(tokenizer.pad_token), vocab) - - # Should contain byte representations - self.assertIn("0", vocab) # First byte - self.assertIn("255", vocab) # Last byte - - self.assertEqual(vocab[str(tokenizer.bos_token)], 1) - self.assertEqual(vocab[str(tokenizer.eos_token)], 2) - self.assertEqual(vocab["0"], 4) # 0 + 4 offset - self.assertEqual(vocab["255"], 259) # 255 + 4 offset - def test_build_inputs_with_special_tokens(self): tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) From 7ccff57e51097d7ffad58c88e81e49fc47cbbbe1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 11 Sep 2025 13:12:57 +0000 Subject: [PATCH 131/139] fix can_record_outputs --- src/transformers/models/blt/modeling_blt.py | 8 ++++++-- src/transformers/models/blt/modular_blt.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index dd86aec87bbe..870b019ecfbd 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -432,13 +432,14 @@ class BltPreTrainedModel(PreTrainedModel): _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } class BltLocalEncoder(BltPreTrainedModel): config: BltLocalEncoderConfig + _can_record_outputs = { + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + } def __init__(self, config: BltLocalEncoderConfig): super().__init__(config) @@ -630,6 +631,9 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): config: BltGlobalTransformerConfig + _can_record_outputs = { + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } def __init__(self, config: BltGlobalTransformerConfig): super().__init__(config) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 0a1201b8cca1..bd886a1398d0 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -378,8 +378,6 @@ class BltPreTrainedModel(MllamaPreTrainedModel): _can_record_outputs = { "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"), "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"), - "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), - "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), } def _init_weights(self, module): @@ -394,6 +392,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(self, module): class BltLocalEncoder(BltPreTrainedModel): config: BltLocalEncoderConfig + _can_record_outputs = { + "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"), + } def __init__(self, config: BltLocalEncoderConfig): super().__init__(config) @@ -585,6 +586,9 @@ def forward( class BltGlobalTransformer(BltPreTrainedModel): config: BltGlobalTransformerConfig + _can_record_outputs = { + "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"), + } def __init__(self, config: BltGlobalTransformerConfig): super().__init__(config) From 90a9a2fb1a5bda3ed75b849e5dc9dc9e991a9071 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 15 Sep 2025 09:12:32 +0000 Subject: [PATCH 132/139] fast tokenizer --- src/transformers/convert_slow_tokenizer.py | 36 +++ .../models/blt/tokenization_blt.py | 84 ++++--- .../models/blt/tokenization_blt_fast.py | 154 ++++++++++++ tests/models/blt/test_modeling_blt.py | 8 +- tests/models/blt/test_tokenization_blt.py | 232 ++++++++---------- 5 files changed, 347 insertions(+), 167 deletions(-) create mode 100644 src/transformers/models/blt/tokenization_blt_fast.py diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index a9e7c9bff5bc..7098296dd010 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1540,6 +1540,41 @@ def post_processor(self): ) +class BltConverter(Converter): + def converted(self) -> Tokenizer: + byte_encoder = bytes_to_unicode() + + vocab = {} + + vocab[""] = 0 + vocab[""] = 1 + vocab[""] = 2 + vocab[""] = 3 + + # Add byte tokens using unicode characters (compatible with ByteLevel) + offset = 4 # Start after special tokens + for byte_val, unicode_char in byte_encoder.items(): + vocab[unicode_char] = byte_val + offset + + tokenizer = Tokenizer(BPE(vocab, [], continuing_subword_prefix="", end_of_word_suffix="")) + + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + tokenizer.decoder = decoders.ByteLevel() + + bos = str(self.original_tokenizer.bos_token) + bos_token_id = self.original_tokenizer.bos_token_id + + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 $B:1", + special_tokens=[ + (bos, bos_token_id), + ], + ) + + return tokenizer + + # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode def bytes_to_unicode(): """ @@ -1653,6 +1688,7 @@ def converted(self) -> Tokenizer: "BertTokenizer": BertConverter, "BigBirdTokenizer": BigBirdConverter, "BlenderbotTokenizer": BlenderbotConverter, + "BltTokenizer": BltConverter, "CamembertTokenizer": CamembertConverter, "CLIPTokenizer": CLIPConverter, "CodeGenTokenizer": GPT2Converter, diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py index 5a92f2b57a33..f2eadf06317b 100644 --- a/src/transformers/models/blt/tokenization_blt.py +++ b/src/transformers/models/blt/tokenization_blt.py @@ -14,6 +14,8 @@ # limitations under the License. """Tokenization classes for Blt.""" +import json +import os from typing import Optional from ...tokenization_utils import AddedToken, PreTrainedTokenizer @@ -22,7 +24,7 @@ logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {} # Blt doesn't require external vocab files +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"} class BltTokenizer(PreTrainedTokenizer): @@ -107,7 +109,6 @@ def __init__( @property def vocab_size(self): """Returns vocab size""" - # Account for byte tokens plus special tokens return self._utf_vocab_size + self.offset def get_vocab(self): @@ -117,7 +118,7 @@ def get_vocab(self): return vocab def _convert_token_to_id(self, token: str) -> int: - """Converts a token (str) to an id using the vocab.""" + """Converts a token (str) in an id using the vocab.""" if token in self.added_tokens_encoder: return self.added_tokens_encoder[token] @@ -132,7 +133,7 @@ def _convert_token_to_id(self, token: str) -> int: return self.unk_token_id def _convert_id_to_token(self, index: int) -> str: - """Converts an index (integer) to a token (str) using the vocab.""" + """Converts an index (integer) in a token (str) using the vocab.""" for token, token_id in self.added_tokens_encoder.items(): if token_id == index: return token @@ -144,25 +145,32 @@ def _convert_id_to_token(self, index: int) -> str: return str(self.unk_token) def convert_tokens_to_string(self, tokens: list[str]) -> str: - """Converts a sequence of tokens to a single string.""" + """Converts a sequence of tokens (string) in a single string.""" byte_values = [] for token in tokens: - # Skip special tokens by checking if they're in added_tokens_encoder - if token in self.added_tokens_encoder: - continue - - try: - byte_val = int(token) - if 0 <= byte_val <= 255: - byte_values.append(byte_val) - except ValueError: - continue + if token in self.added_tokens_decoder: + tok_string = self.added_tokens_decoder[token].encode("utf-8") + byte_values.extend(tok_string) + elif token in self.added_tokens_encoder: + tok_string = token.encode("utf-8") + byte_values.extend(tok_string) + else: + try: + byte_val = int(token) + if 0 <= byte_val <= 255: + byte_values.append(byte_val) + except ValueError: + continue return bytes(byte_values).decode("utf-8", errors="ignore") def _tokenize(self, text: str, **kwargs) -> list[str]: - """Converts a string to a list of tokens. For Blt, we work directly with byte values.""" + """ + Args: + text: TextInput + Returns a tokenized string. For Blt, we work directly with byte values. + """ return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] def build_inputs_with_special_tokens( @@ -176,13 +184,13 @@ def build_inputs_with_special_tokens( - pair of sequences: ` A B ` Args: - token_ids_0 (`List[int]`): + token_ids_0 (`list[int]`): List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): + token_ids_1 (`list[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. """ bos = [self.bos_token_id] if self.add_bos_token else [] eos = [self.eos_token_id] if self.add_eos_token else [] @@ -199,15 +207,15 @@ def get_special_tokens_mask( special tokens using the tokenizer `prepare_for_model` method. Args: - token_ids_0 (`List[int]`): + token_ids_0 (`list[int]`): List of IDs. - token_ids_1 (`List[int]`, *optional*): + token_ids_1 (`list[int]`, *optional*): Optional second list of IDs for sequence pairs. already_has_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not the token list is already formatted with special tokens for the model. Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: return super().get_special_tokens_mask( @@ -221,13 +229,33 @@ def get_special_tokens_mask( return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + ([0] * len(token_ids_1)) + eos_token_id - def get_vocab_size(self) -> int: - """Get vocab size like the original tokenizer.""" - return self.vocab_size - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: - # Blt doesn't require external vocabulary files since it uses byte-level tokenization - return () + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the vocabulary file name. + + Returns: + `tuple[str]`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return () + + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + vocab = self.get_vocab() + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + return (vocab_file,) __all__ = ["BltTokenizer"] diff --git a/src/transformers/models/blt/tokenization_blt_fast.py b/src/transformers/models/blt/tokenization_blt_fast.py new file mode 100644 index 000000000000..ff08d4670486 --- /dev/null +++ b/src/transformers/models/blt/tokenization_blt_fast.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast tokenization classes for Blt.""" + +from typing import Optional + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_blt import BltTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "tokenizer_file": "tokenizer.json"} + + +class BltTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Blt tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level + tokenization where each byte is treated as a token. + + This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. + It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), + beginning of example (BOE), and padding (PAD). + + ```python + >>> from transformers import BltTokenizerFast + + >>> tokenizer = BltTokenizerFast.from_pretrained("path/to/blt/model") + >>> tokenizer("Hello world")["input_ids"] + [1, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] + + >>> tokenizer(" Hello world")["input_ids"] + [1, 32, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] + ``` + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + tokenizer_file (`str`, *optional*): + Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of example token used for marking the start of individual examples in a sequence. + additional_special_tokens (`list[str]` or `list[tokenizers.AddedToken]`, *optional*): + A list of additional special tokens to be added to the tokenizer's vocabulary. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = BltTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + boe_token="", + additional_special_tokens=None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + spaces_between_special_tokens=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + boe_token=boe_token, + additional_special_tokens=additional_special_tokens, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + # Store Blt-specific parameters + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and + adding special tokens. A Blt sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + bos = [self.bos_token_id] if self.add_bos_token else [] + eos = [self.eos_token_id] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos + token_ids_0 + eos + return bos + token_ids_0 + eos + token_ids_1 + eos + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["BltTokenizerFast"] diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index 9336d1d0b1c2..dc4703974781 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -316,7 +316,7 @@ def test_model(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False ) - output_text = tokenizer.decode(generated_ids[0]) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) @slow @@ -422,7 +422,7 @@ def test_model_bf16(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False ) - output_text = tokenizer.decode(generated_ids[0]) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) @slow @@ -530,7 +530,7 @@ def test_model_eager(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False ) - output_text = tokenizer.decode(generated_ids[0]) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) @slow @@ -557,5 +557,5 @@ def test_model_bf16_static_cache(self): **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False ) - output_text = tokenizer.decode(generated_ids[0]) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py index b14cd3ab64a8..e9e7fde10f48 100644 --- a/tests/models/blt/test_tokenization_blt.py +++ b/tests/models/blt/test_tokenization_blt.py @@ -14,9 +14,8 @@ import unittest -from transformers import BltTokenizer +from transformers import BltTokenizer, BltTokenizerFast from transformers.testing_utils import require_tokenizers -from transformers.tokenization_utils import AddedToken from ...test_tokenization_common import TokenizerTesterMixin @@ -25,9 +24,9 @@ class BltTokenizationTest(TokenizerTesterMixin, unittest.TestCase): from_pretrained_id = [] tokenizer_class = BltTokenizer - rust_tokenizer_class = None + rust_tokenizer_class = BltTokenizerFast - test_rust_tokenizer = False + test_rust_tokenizer = True test_sentencepiece = False test_slow_tokenizer = True from_pretrained_kwargs = {} @@ -43,172 +42,135 @@ def get_tokenizers(self, **kwargs): kwargs.update({"add_bos_token": True, "add_eos_token": False}) return super().get_tokenizers(**kwargs) - def test_simple_encode_decode(self): + def test_unicode_handling(self): tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - text = "Hello" + # Test Unicode character (é) + text = "café" encoded = tokenizer.encode(text, add_special_tokens=False) - - # "Hello" in UTF-8 bytes: [72, 101, 108, 108, 111] - # With offset +4: [76, 105, 112, 112, 115] - expected = [76, 105, 112, 112, 115] + # "café" in UTF-8 bytes: [99, 97, 102, 195, 169] (é = 195, 169) + expected = [byte_val + tokenizer.offset for byte_val in [99, 97, 102, 195, 169]] self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) self.assertEqual(decoded, text) - def test_special_tokens_encoding(self): - tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) - - text = "Hi" - encoded = tokenizer.encode(text, add_special_tokens=True) - - # "Hi" in UTF-8 bytes: [72, 105] -> with offset: [76, 109] - # With BOS (1) and EOS (2): [1, 76, 109, 2] - expected = [1, 76, 109, 2] + # Test emoji + text = "Hello 👋" + encoded = tokenizer.encode(text, add_special_tokens=False) + # "Hello 👋" in UTF-8 bytes: [72, 101, 108, 108, 111, 32, 240, 159, 145, 139] (👋 = 240, 159, 145, 139) + expected = [byte_val + tokenizer.offset for byte_val in [72, 101, 108, 108, 111, 32, 240, 159, 145, 139]] self.assertEqual(encoded, expected) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, text) - def test_tokenize_method(self): - tokenizer = BltTokenizer() - - text = "ABC" - tokens = tokenizer._tokenize(text) - - # "ABC" in UTF-8 bytes: [65, 66, 67] - expected = ["65", "66", "67"] - self.assertEqual(tokens, expected) - - def test_token_conversion(self): - """Test token to ID and ID to token conversion""" - tokenizer = BltTokenizer() - - # Test byte token conversion - token = "65" # Byte value for 'A' - token_id = tokenizer._convert_token_to_id(token) - self.assertEqual(token_id, 69) # 65 + 4 offset - - converted_token = tokenizer._convert_id_to_token(token_id) - self.assertEqual(converted_token, token) - - bos_id = tokenizer._convert_token_to_id(str(tokenizer.bos_token)) - self.assertEqual(bos_id, 1) - - bos_token = tokenizer._convert_id_to_token(1) - self.assertEqual(bos_token, str(tokenizer.bos_token)) - - def test_convert_tokens_to_string(self): - tokenizer = BltTokenizer() - - tokens = ["72", "101", "108", "108", "111"] # "Hello" in bytes - result = tokenizer.convert_tokens_to_string(tokens) - self.assertEqual(result, "Hello") - - # Test with special tokens mixed in (should be ignored) - tokens_with_special = [str(tokenizer.bos_token), "72", "105", str(tokenizer.eos_token)] - result = tokenizer.convert_tokens_to_string(tokens_with_special) - self.assertEqual(result, "Hi") - - def test_unicode_handling(self): + def test_special_characters_and_unicode(self): tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - # Test Unicode character (é) - text = "café" + # Test special characters with unicode + text = "Hello, 世界! 🌍" encoded = tokenizer.encode(text, add_special_tokens=False) + expected = [ + byte_val + tokenizer.offset + for byte_val in [72, 101, 108, 108, 111, 44, 32, 228, 184, 150, 231, 149, 140, 33, 32, 240, 159, 140, 141] + ] + self.assertEqual(encoded, expected) decoded = tokenizer.decode(encoded) self.assertEqual(decoded, text) - # Test emoji - text = "Hello 👋" + # Test mixed special characters, numbers, and unicode + text = "Price: $100.50 €75.25 🎉" encoded = tokenizer.encode(text, add_special_tokens=False) + expected = [ + byte_val + tokenizer.offset + for byte_val in [ + 80, + 114, + 105, + 99, + 101, + 58, + 32, + 36, + 49, + 48, + 48, + 46, + 53, + 48, + 32, + 226, + 130, + 172, + 55, + 53, + 46, + 50, + 53, + 32, + 240, + 159, + 142, + 137, + ] + ] + self.assertEqual(encoded, expected) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, text) + + # Test control characters with unicode + text = "Line1\nLine2\tTabbed 中文" + encoded = tokenizer.encode(text, add_special_tokens=False) + expected = [ + byte_val + tokenizer.offset + for byte_val in [ + 76, + 105, + 110, + 101, + 49, + 10, + 76, + 105, + 110, + 101, + 50, + 9, + 84, + 97, + 98, + 98, + 101, + 100, + 32, + 228, + 184, + 173, + 230, + 150, + 135, + ] + ] + self.assertEqual(encoded, expected) decoded = tokenizer.decode(encoded) self.assertEqual(decoded, text) def test_empty_and_whitespace(self): tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - # Test empty string encoded = tokenizer.encode("", add_special_tokens=False) self.assertEqual(encoded, []) decoded = tokenizer.decode(encoded) self.assertEqual(decoded, "") - # Test single space encoded = tokenizer.encode(" ", add_special_tokens=False) - self.assertEqual(encoded, [36]) # 32 (space) + 4 offset + self.assertEqual(encoded, [32 + tokenizer.offset]) # space + offset decoded = tokenizer.decode(encoded) self.assertEqual(decoded, " ") - def test_build_inputs_with_special_tokens(self): - tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) - - # Single sequence - token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) - result = tokenizer.build_inputs_with_special_tokens(token_ids) - expected = [1, 76, 109, 2] # BOS + tokens + EOS - self.assertEqual(result, expected) - - # Pair of sequences - token_ids_1 = [76, 109] # "Hi" - token_ids_2 = [66, 121, 101] # "Bye" - result = tokenizer.build_inputs_with_special_tokens(token_ids_1, token_ids_2) - expected = [1, 76, 109, 2, 66, 121, 101, 2] # BOS + seq1 + EOS + seq2 + EOS - self.assertEqual(result, expected) - - def test_special_tokens_mask(self): - tokenizer = BltTokenizer(add_bos_token=True, add_eos_token=True) - - token_ids = [76, 109] # "Hi" encoded (H=72+4=76, i=105+4=109) - mask = tokenizer.get_special_tokens_mask(token_ids) - expected = [1, 0, 0, 1] # BOS=1, content=0, content=0, EOS=1 - self.assertEqual(mask, expected) - - def test_add_special_tokens_flags(self): - tokenizer1 = BltTokenizer(add_bos_token=True, add_eos_token=True) - encoded1 = tokenizer1.encode("Hi", add_special_tokens=True) - self.assertEqual(encoded1[0], 1) # BOS - self.assertEqual(encoded1[-1], 2) # EOS - - tokenizer2 = BltTokenizer(add_bos_token=False, add_eos_token=False) - encoded2 = tokenizer2.encode("Hi", add_special_tokens=True) - self.assertNotEqual(encoded2[0], 1) # No BOS - self.assertNotEqual(encoded2[-1], 2) # No EOS - - # Test with only BOS - tokenizer3 = BltTokenizer(add_bos_token=True, add_eos_token=False) - encoded3 = tokenizer3.encode("Hi", add_special_tokens=True) - self.assertEqual(encoded3[0], 1) # BOS - self.assertNotEqual(encoded3[-1], 2) # No EOS - - def test_added_tokens(self): - tokenizer = BltTokenizer() - - custom_token = AddedToken("", normalized=False, special=True) - tokenizer.add_tokens([custom_token]) - - self.assertIn("", tokenizer.get_vocab()) - - token_id = tokenizer._convert_token_to_id("") - self.assertIsInstance(token_id, int) - - back_token = tokenizer._convert_id_to_token(token_id) - self.assertEqual(back_token, "") - - @unittest.skip("Blt is byte-level, special tokens are encoded as bytes") - def test_add_special_tokens(self): - pass - @unittest.skip("Blt byte-level tokenization doesn't handle pretokenized inputs the same way") def test_pretokenized_inputs(self): pass - @unittest.skip("Blt encodes added tokens as bytes, not single tokens") - def test_add_tokens_tokenizer(self): - pass - - @unittest.skip("Blt tokenizer serialization needs additional work for added tokens") - def test_save_and_load_tokenizer(self): - pass - if __name__ == "__main__": unittest.main() From 180042d9b405b342fbe36841359f137fbf212d43 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 15 Sep 2025 12:22:46 +0000 Subject: [PATCH 133/139] no more modulelist --- .../models/blt/convert_blt_weights_to_hf.py | 39 +++++++++++++++- src/transformers/models/blt/modeling_blt.py | 46 ++++++++----------- src/transformers/models/blt/modular_blt.py | 46 ++++++++----------- 3 files changed, 75 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index 7b5fc30d641e..c5e96f474214 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -100,7 +100,7 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> dict[str "vocab_size": unified_config.get("vocab_size", 256), "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), "cross_attn_k": unified_config.get("cross_attn_k", 2), - "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "hidden_size_global": unified_config.get("dim_global", 2048), "hidden_size": decoder_hidden_size, "num_attention_heads": unified_config.get("n_heads_local_decoder", 16), "num_key_value_heads": unified_config.get("n_kv_heads"), @@ -203,6 +203,39 @@ def apply_weight_mapping(state_dict: dict[str, torch.Tensor]) -> dict[str, torch return new_state_dict +def convert_hash_embeddings_to_fused( + unified_weights: dict[str, torch.Tensor], config: dict[str, Any] +) -> dict[str, torch.Tensor]: + """Convert ModuleList hash embeddings to nn.embedding format""" + original_keys_format = [ + key + for key in unified_weights.keys() + if "encoder_hash_tok_embedding." in key and ".weight" in key and key.split(".")[-2].isdigit() + ] + + num_embeddings = config.get("encoder_hash_byte_group_nb_functions", 1) * len( + config.get("encoder_hash_byte_group_size", [3, 4, 5, 6, 7, 8]) + ) + vocab_size = config.get("encoder_hash_byte_group_vocab", 500002) + hidden_size = config.get("encoder_config", {}).get("hidden_size", 1024) + + fused_weight = torch.zeros(vocab_size * num_embeddings, hidden_size) + + sorted_keys = sorted(original_keys_format, key=lambda k: int(k.split(".")[-2])) + + for i, old_key in enumerate(sorted_keys): + start_idx = i * vocab_size + end_idx = (i + 1) * vocab_size + fused_weight[start_idx:end_idx] = unified_weights[old_key] + logger.info(f"Copied {old_key} to indices {start_idx}:{end_idx}") + del unified_weights[old_key] + + fused_key = "model.encoder_hash_tok_embedding.weight" + unified_weights[fused_key] = fused_weight + + return unified_weights + + def merge_weights(weights_path: str, entropy_weights_path: str) -> dict[str, torch.Tensor]: main_weights = load_file(weights_path) @@ -301,6 +334,8 @@ def convert_hf_blt_to_unified( unified_config = merge_configurations(config_path, entropy_params_path) unified_weights = merge_weights(weights_path, entropy_weights_path) + unified_weights = convert_hash_embeddings_to_fused(unified_weights, unified_config) + os.makedirs(output_dir, exist_ok=True) config_path = os.path.join(output_dir, config_name) @@ -336,7 +371,7 @@ def main(): parser.add_argument( "--model_id", type=str, - default="facebook/blt-1b", + default="facebook/blt-7b", ) parser.add_argument( "--output_dir", diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 870b019ecfbd..443f8bd09efe 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -819,7 +819,6 @@ def patch_lengths_from_entropies( Computes patch lengths from token entropies. Depending on whether a threshold is provided, the function uses either: - - Top-k selection based on entropy (when `threshold` is None), or - Thresholding the entropy values (when `threshold` is set). """ @@ -834,29 +833,23 @@ def patch_lengths_from_entropies( # Ignore first token entropy (BOS) entropies = entropies[:, 1:] - if threshold is None: - # Use top-k entropy values to define patch start points - num_patches = sequence_length // patch_size - topk_indices = entropies.topk(num_patches - 2, dim=1).indices - patch_starts = topk_indices.sort(dim=1).values - else: - # Threshold the entropy values to define patch start points - patch_mask = entropies > threshold + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold - seq_len = patch_mask.shape[1] + seq_len = patch_mask.shape[1] - # Create patch IDs (token indices), and add a sentinel to ensure alignment - token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) - sentinel = torch.full_like(token_indices, seq_len) - padded_indices = torch.cat([token_indices, sentinel], dim=1) + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) - # Pad mask with inverse to align sentinel correctly - padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) - # Select indices where mask is True - patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) - max_valid_patches = patch_mask.sum(dim=1).max() - patch_starts = patch_starts[:, :max_valid_patches] + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] # Offset patch starts to account for the two initial tokens patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) @@ -920,7 +913,7 @@ def byte_group_hash_function( def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_tok_embedding: nn.Embedding, encoder_hash_byte_group_nb_functions: int, encoder_hash_byte_group_size: list, encoder_hash_byte_group_vocab: int, @@ -947,7 +940,9 @@ def compute_hash_embeddings( prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes for group_size in encoder_hash_byte_group_size: hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) - embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + # Apply offset to get the correct slice of the fused embedding + offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab + embeddings += encoder_hash_tok_embedding(offset_hash_ids) embedding_idx += 1 return embeddings @@ -1041,11 +1036,8 @@ def __init__(self, config: BltConfig): self.global_transformer = BltGlobalTransformer(config.global_config) self.local_decoder = BltLocalDecoder(config.decoder_config) num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) - embeddings = [ - nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) - for _ in range(num_embeddings) - ] - self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) + total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings + self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size) if self.config.patch_in_forward: self.patcher = BltPatcher(config.patcher_config) self.patcher.eval() diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index bd886a1398d0..0b04966d97fe 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -104,7 +104,7 @@ def byte_group_hash_function( def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, - encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_tok_embedding: nn.Embedding, encoder_hash_byte_group_nb_functions: int, encoder_hash_byte_group_size: list, encoder_hash_byte_group_vocab: int, @@ -131,7 +131,9 @@ def compute_hash_embeddings( prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes for group_size in encoder_hash_byte_group_size: hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab) - embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + # Apply offset to get the correct slice of the fused embedding + offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab + embeddings += encoder_hash_tok_embedding(offset_hash_ids) embedding_idx += 1 return embeddings @@ -730,7 +732,6 @@ def patch_lengths_from_entropies( Computes patch lengths from token entropies. Depending on whether a threshold is provided, the function uses either: - - Top-k selection based on entropy (when `threshold` is None), or - Thresholding the entropy values (when `threshold` is set). """ @@ -745,29 +746,23 @@ def patch_lengths_from_entropies( # Ignore first token entropy (BOS) entropies = entropies[:, 1:] - if threshold is None: - # Use top-k entropy values to define patch start points - num_patches = sequence_length // patch_size - topk_indices = entropies.topk(num_patches - 2, dim=1).indices - patch_starts = topk_indices.sort(dim=1).values - else: - # Threshold the entropy values to define patch start points - patch_mask = entropies > threshold + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold - seq_len = patch_mask.shape[1] + seq_len = patch_mask.shape[1] - # Create patch IDs (token indices), and add a sentinel to ensure alignment - token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) - sentinel = torch.full_like(token_indices, seq_len) - padded_indices = torch.cat([token_indices, sentinel], dim=1) + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) - # Pad mask with inverse to align sentinel correctly - padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) - # Select indices where mask is True - patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) - max_valid_patches = patch_mask.sum(dim=1).max() - patch_starts = patch_starts[:, :max_valid_patches] + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] # Offset patch starts to account for the two initial tokens patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) @@ -791,11 +786,8 @@ def __init__(self, config: BltConfig): self.global_transformer = BltGlobalTransformer(config.global_config) self.local_decoder = BltLocalDecoder(config.decoder_config) num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size) - embeddings = [ - nn.Embedding(config.encoder_hash_byte_group_vocab, config.encoder_config.hidden_size) - for _ in range(num_embeddings) - ] - self.encoder_hash_tok_embedding = nn.ModuleList(embeddings) + total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings + self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size) if self.config.patch_in_forward: self.patcher = BltPatcher(config.patcher_config) self.patcher.eval() From c495819b8245ed2a12cc044edeb0fdea56c1ecc1 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 08:44:40 +0000 Subject: [PATCH 134/139] tok auto --- src/transformers/models/auto/tokenization_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 3b00eeed8114..e200331b03d1 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -105,7 +105,7 @@ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)), - ("blt", ("BltTokenizer", None)), + ("blt", ("BltTokenizer", "BltTokenizerFast" if is_tokenizers_available() else None)), ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("byt5", ("ByT5Tokenizer", None)), From 5607b5add1d83bedfe9db3d1fa9253475589861d Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 13:36:53 +0000 Subject: [PATCH 135/139] rm tokenizersss --- src/transformers/convert_slow_tokenizer.py | 36 --- .../models/auto/tokenization_auto.py | 2 +- .../models/blt/convert_blt_weights_to_hf.py | 50 +++- .../models/blt/tokenization_blt.py | 261 ------------------ .../models/blt/tokenization_blt_fast.py | 154 ----------- tests/models/blt/test_tokenization_blt.py | 176 ------------ 6 files changed, 50 insertions(+), 629 deletions(-) delete mode 100644 src/transformers/models/blt/tokenization_blt.py delete mode 100644 src/transformers/models/blt/tokenization_blt_fast.py delete mode 100644 tests/models/blt/test_tokenization_blt.py diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 7098296dd010..a9e7c9bff5bc 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1540,41 +1540,6 @@ def post_processor(self): ) -class BltConverter(Converter): - def converted(self) -> Tokenizer: - byte_encoder = bytes_to_unicode() - - vocab = {} - - vocab[""] = 0 - vocab[""] = 1 - vocab[""] = 2 - vocab[""] = 3 - - # Add byte tokens using unicode characters (compatible with ByteLevel) - offset = 4 # Start after special tokens - for byte_val, unicode_char in byte_encoder.items(): - vocab[unicode_char] = byte_val + offset - - tokenizer = Tokenizer(BPE(vocab, [], continuing_subword_prefix="", end_of_word_suffix="")) - - tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) - tokenizer.decoder = decoders.ByteLevel() - - bos = str(self.original_tokenizer.bos_token) - bos_token_id = self.original_tokenizer.bos_token_id - - tokenizer.post_processor = processors.TemplateProcessing( - single=f"{bos}:0 $A:0", - pair=f"{bos}:0 $A:0 $B:1", - special_tokens=[ - (bos, bos_token_id), - ], - ) - - return tokenizer - - # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode def bytes_to_unicode(): """ @@ -1688,7 +1653,6 @@ def converted(self) -> Tokenizer: "BertTokenizer": BertConverter, "BigBirdTokenizer": BigBirdConverter, "BlenderbotTokenizer": BlenderbotConverter, - "BltTokenizer": BltConverter, "CamembertTokenizer": CamembertConverter, "CLIPTokenizer": CLIPConverter, "CodeGenTokenizer": GPT2Converter, diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e200331b03d1..52726fd6200a 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -105,7 +105,7 @@ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)), - ("blt", ("BltTokenizer", "BltTokenizerFast" if is_tokenizers_available() else None)), + ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("byt5", ("ByT5Tokenizer", None)), diff --git a/src/transformers/models/blt/convert_blt_weights_to_hf.py b/src/transformers/models/blt/convert_blt_weights_to_hf.py index c5e96f474214..f9decff3a1f8 100644 --- a/src/transformers/models/blt/convert_blt_weights_to_hf.py +++ b/src/transformers/models/blt/convert_blt_weights_to_hf.py @@ -7,7 +7,11 @@ import torch from huggingface_hub import hf_hub_download, upload_folder from safetensors.torch import load_file, save_file +from tokenizers import Tokenizer, decoders, pre_tokenizers, processors +from tokenizers.models import BPE +from transformers import PreTrainedTokenizerFast +from transformers.convert_slow_tokenizer import bytes_to_unicode from transformers.utils import logging as transformers_logging @@ -275,9 +279,10 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> dict[str, tor def create_tokenizer_config(output_dir: str, config: dict[str, Any]): tokenizer_config = { - "tokenizer_class": "BltTokenizer", + "tokenizer_class": "PreTrainedTokenizerFast", "vocab_size": config.get("vocab_size", 256), "model_max_length": config.get("max_seqlen", 1024), + "model_input_names": ["input_ids", "attention_mask"], "add_bos_token": True, "add_eos_token": True, "bos_token": "", @@ -291,6 +296,47 @@ def create_tokenizer_config(output_dir: str, config: dict[str, Any]): json.dump(tokenizer_config, f, indent=2) +def create_tokenizer_json(output_dir: str, config: dict[str, Any]): + byte_encoder = bytes_to_unicode() + + vocab: dict[str, int] = {} + vocab[""] = 0 + vocab[""] = 1 + vocab[""] = 2 + vocab[""] = 3 + + offset = 4 + for byte_val, unicode_char in byte_encoder.items(): + vocab[unicode_char] = byte_val + offset + + backend = Tokenizer( + BPE(vocab=vocab, merges=[], continuing_subword_prefix="", end_of_word_suffix="", fuse_unk=False) + ) + backend.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + backend.decoder = decoders.ByteLevel() + + bos = config.get("bos_token", "") + backend.post_processor = processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 $B:1", + special_tokens=[(bos, 1)], + ) + + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=backend, + bos_token=config.get("bos_token", ""), + eos_token=config.get("eos_token", ""), + pad_token=config.get("pad_token", ""), + unk_token=config.get("unk_token", ""), + ) + + tokenizer.add_bos_token = bool(config.get("add_bos_token", True)) + tokenizer.add_eos_token = bool(config.get("add_eos_token", True)) + + tokenizer.save_pretrained(output_dir) + logger.info(f"Saved tokenizer.json to {os.path.join(output_dir, 'tokenizer.json')}") + + def push_to_hub( local_dir: str, repo_id: str, @@ -348,6 +394,8 @@ def convert_hf_blt_to_unified( weights_path = os.path.join(output_dir, weights_name) save_file(unified_weights, weights_path) + create_tokenizer_json(output_dir=output_dir, config=unified_config) + create_tokenizer_config(output_dir, unified_config) logger.info(f"Conversion completed, model saved to: {output_dir}") diff --git a/src/transformers/models/blt/tokenization_blt.py b/src/transformers/models/blt/tokenization_blt.py deleted file mode 100644 index f2eadf06317b..000000000000 --- a/src/transformers/models/blt/tokenization_blt.py +++ /dev/null @@ -1,261 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tokenization classes for Blt.""" - -import json -import os -from typing import Optional - -from ...tokenization_utils import AddedToken, PreTrainedTokenizer -from ...utils import logging - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"} - - -class BltTokenizer(PreTrainedTokenizer): - """ - Construct a Blt tokenizer. Based on byte-level tokenization where each byte is treated as a token. - - This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. - It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), - beginning of example (BOE), and padding (PAD). - - Args: - bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The end of sequence token. - pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by - attention mechanisms or loss computation. - unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of example token used for marking the start of individual examples in a sequence. - additional_special_tokens (`list[str]` or `list[tokenizers.AddedToken]`, *optional*): - A list of additional special tokens to be added to the tokenizer's vocabulary. - add_bos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add an `bos_token` at the start of sequences. - add_eos_token (`bool`, *optional*, defaults to `False`): - Whether or not to add an `eos_token` at the end of sequences. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like - extra spaces. - spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to add spaces between special tokens. - """ - - vocab_files_names = VOCAB_FILES_NAMES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - bos_token="", - eos_token="", - pad_token="", - unk_token="", - boe_token="", - additional_special_tokens=None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - spaces_between_special_tokens=False, - **kwargs, - ): - # Store Blt-specific parameters first - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - self.byte_vocab_size = 256 # byte units (0-255) - - boe_token = AddedToken(boe_token, lstrip=False, rstrip=False) if isinstance(boe_token, str) else boe_token - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token - - self._added_tokens_decoder = {0: boe_token, 1: bos_token, 2: eos_token, 3: pad_token} - self.offset = len(self._added_tokens_decoder) - self._utf_vocab_size = 2**8 # utf is 8 bits - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - pad_token=pad_token, - unk_token=unk_token, - boe_token=boe_token, - additional_special_tokens=additional_special_tokens, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - spaces_between_special_tokens=spaces_between_special_tokens, - **kwargs, - ) - - @property - def vocab_size(self): - """Returns vocab size""" - return self._utf_vocab_size + self.offset - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def _convert_token_to_id(self, token: str) -> int: - """Converts a token (str) in an id using the vocab.""" - if token in self.added_tokens_encoder: - return self.added_tokens_encoder[token] - - # For byte tokens, convert string representation to integer and add offset - try: - byte_val = int(token) - if 0 <= byte_val <= 255: - return byte_val + self.offset - except ValueError: - pass - - return self.unk_token_id - - def _convert_id_to_token(self, index: int) -> str: - """Converts an index (integer) in a token (str) using the vocab.""" - for token, token_id in self.added_tokens_encoder.items(): - if token_id == index: - return token - - # For byte tokens, subtract offset and convert to string - if self.offset <= index < self.vocab_size: - return str(index - self.offset) - - return str(self.unk_token) - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - """Converts a sequence of tokens (string) in a single string.""" - byte_values = [] - - for token in tokens: - if token in self.added_tokens_decoder: - tok_string = self.added_tokens_decoder[token].encode("utf-8") - byte_values.extend(tok_string) - elif token in self.added_tokens_encoder: - tok_string = token.encode("utf-8") - byte_values.extend(tok_string) - else: - try: - byte_val = int(token) - if 0 <= byte_val <= 255: - byte_values.append(byte_val) - except ValueError: - continue - - return bytes(byte_values).decode("utf-8", errors="ignore") - - def _tokenize(self, text: str, **kwargs) -> list[str]: - """ - Args: - text: TextInput - Returns a tokenized string. For Blt, we work directly with byte values. - """ - return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")] - - def build_inputs_with_special_tokens( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None - ) -> list[int]: - """ - Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and - adding special tokens. A Blt sequence has the following format: - - - single sequence: ` X ` - - pair of sequences: ` A B ` - - Args: - token_ids_0 (`list[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`list[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - bos = [self.bos_token_id] if self.add_bos_token else [] - eos = [self.eos_token_id] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos + token_ids_0 + eos - return bos + token_ids_0 + eos + token_ids_1 + eos - - def get_special_tokens_mask( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False - ) -> list[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`list[int]`): - List of IDs. - token_ids_1 (`list[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + ([0] * len(token_ids_1)) + eos_token_id - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - filename_prefix (`str`, *optional*): - An optional prefix to add to the vocabulary file name. - - Returns: - `tuple[str]`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return () - - vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - vocab = self.get_vocab() - - with open(vocab_file, "w", encoding="utf-8") as f: - f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n") - - return (vocab_file,) - - -__all__ = ["BltTokenizer"] diff --git a/src/transformers/models/blt/tokenization_blt_fast.py b/src/transformers/models/blt/tokenization_blt_fast.py deleted file mode 100644 index ff08d4670486..000000000000 --- a/src/transformers/models/blt/tokenization_blt_fast.py +++ /dev/null @@ -1,154 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Fast tokenization classes for Blt.""" - -from typing import Optional - -from ...tokenization_utils_fast import PreTrainedTokenizerFast -from ...utils import logging -from .tokenization_blt import BltTokenizer - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "tokenizer_file": "tokenizer.json"} - - -class BltTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a "fast" Blt tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level - tokenization where each byte is treated as a token. - - This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. - It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), - beginning of example (BOE), and padding (PAD). - - ```python - >>> from transformers import BltTokenizerFast - - >>> tokenizer = BltTokenizerFast.from_pretrained("path/to/blt/model") - >>> tokenizer("Hello world")["input_ids"] - [1, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] - - >>> tokenizer(" Hello world")["input_ids"] - [1, 32, 72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100] - ``` - - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - - Args: - vocab_file (`str`, *optional*): - Path to the vocabulary file. - tokenizer_file (`str`, *optional*): - Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that - contains everything needed to load the tokenizer. - bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The end of sequence token. - pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by - attention mechanisms or loss computation. - unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of example token used for marking the start of individual examples in a sequence. - additional_special_tokens (`list[str]` or `list[tokenizers.AddedToken]`, *optional*): - A list of additional special tokens to be added to the tokenizer's vocabulary. - add_bos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add an `bos_token` at the start of sequences. - add_eos_token (`bool`, *optional*, defaults to `False`): - Whether or not to add an `eos_token` at the end of sequences. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like - extra spaces. - spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to add spaces between special tokens. - """ - - vocab_files_names = VOCAB_FILES_NAMES - model_input_names = ["input_ids", "attention_mask"] - slow_tokenizer_class = BltTokenizer - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - bos_token="", - eos_token="", - pad_token="", - unk_token="", - boe_token="", - additional_special_tokens=None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - spaces_between_special_tokens=False, - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - bos_token=bos_token, - eos_token=eos_token, - pad_token=pad_token, - unk_token=unk_token, - boe_token=boe_token, - additional_special_tokens=additional_special_tokens, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - spaces_between_special_tokens=spaces_between_special_tokens, - **kwargs, - ) - - # Store Blt-specific parameters - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - - def build_inputs_with_special_tokens( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None - ) -> list[int]: - """ - Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and - adding special tokens. A Blt sequence has the following format: - - - single sequence: ` X ` - - pair of sequences: ` A B ` - - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - bos = [self.bos_token_id] if self.add_bos_token else [] - eos = [self.eos_token_id] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos + token_ids_0 + eos - return bos + token_ids_0 + eos + token_ids_1 + eos - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: - files = self._tokenizer.model.save(save_directory, name=filename_prefix) - return tuple(files) - - -__all__ = ["BltTokenizerFast"] diff --git a/tests/models/blt/test_tokenization_blt.py b/tests/models/blt/test_tokenization_blt.py deleted file mode 100644 index e9e7fde10f48..000000000000 --- a/tests/models/blt/test_tokenization_blt.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from transformers import BltTokenizer, BltTokenizerFast -from transformers.testing_utils import require_tokenizers - -from ...test_tokenization_common import TokenizerTesterMixin - - -@require_tokenizers -class BltTokenizationTest(TokenizerTesterMixin, unittest.TestCase): - from_pretrained_id = [] - tokenizer_class = BltTokenizer - rust_tokenizer_class = BltTokenizerFast - - test_rust_tokenizer = True - test_sentencepiece = False - test_slow_tokenizer = True - from_pretrained_kwargs = {} - - @classmethod - def setUpClass(cls): - super().setUpClass() - # Create a simple Blt tokenizer for testing - tokenizer = BltTokenizer() - tokenizer.save_pretrained(cls.tmpdirname) - - def get_tokenizers(self, **kwargs): - kwargs.update({"add_bos_token": True, "add_eos_token": False}) - return super().get_tokenizers(**kwargs) - - def test_unicode_handling(self): - tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - - # Test Unicode character (é) - text = "café" - encoded = tokenizer.encode(text, add_special_tokens=False) - # "café" in UTF-8 bytes: [99, 97, 102, 195, 169] (é = 195, 169) - expected = [byte_val + tokenizer.offset for byte_val in [99, 97, 102, 195, 169]] - self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, text) - - # Test emoji - text = "Hello 👋" - encoded = tokenizer.encode(text, add_special_tokens=False) - # "Hello 👋" in UTF-8 bytes: [72, 101, 108, 108, 111, 32, 240, 159, 145, 139] (👋 = 240, 159, 145, 139) - expected = [byte_val + tokenizer.offset for byte_val in [72, 101, 108, 108, 111, 32, 240, 159, 145, 139]] - self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, text) - - def test_special_characters_and_unicode(self): - tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - - # Test special characters with unicode - text = "Hello, 世界! 🌍" - encoded = tokenizer.encode(text, add_special_tokens=False) - expected = [ - byte_val + tokenizer.offset - for byte_val in [72, 101, 108, 108, 111, 44, 32, 228, 184, 150, 231, 149, 140, 33, 32, 240, 159, 140, 141] - ] - self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, text) - - # Test mixed special characters, numbers, and unicode - text = "Price: $100.50 €75.25 🎉" - encoded = tokenizer.encode(text, add_special_tokens=False) - expected = [ - byte_val + tokenizer.offset - for byte_val in [ - 80, - 114, - 105, - 99, - 101, - 58, - 32, - 36, - 49, - 48, - 48, - 46, - 53, - 48, - 32, - 226, - 130, - 172, - 55, - 53, - 46, - 50, - 53, - 32, - 240, - 159, - 142, - 137, - ] - ] - self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, text) - - # Test control characters with unicode - text = "Line1\nLine2\tTabbed 中文" - encoded = tokenizer.encode(text, add_special_tokens=False) - expected = [ - byte_val + tokenizer.offset - for byte_val in [ - 76, - 105, - 110, - 101, - 49, - 10, - 76, - 105, - 110, - 101, - 50, - 9, - 84, - 97, - 98, - 98, - 101, - 100, - 32, - 228, - 184, - 173, - 230, - 150, - 135, - ] - ] - self.assertEqual(encoded, expected) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, text) - - def test_empty_and_whitespace(self): - tokenizer = BltTokenizer(add_bos_token=False, add_eos_token=False) - - encoded = tokenizer.encode("", add_special_tokens=False) - self.assertEqual(encoded, []) - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, "") - - encoded = tokenizer.encode(" ", add_special_tokens=False) - self.assertEqual(encoded, [32 + tokenizer.offset]) # space + offset - decoded = tokenizer.decode(encoded) - self.assertEqual(decoded, " ") - - @unittest.skip("Blt byte-level tokenization doesn't handle pretokenized inputs the same way") - def test_pretokenized_inputs(self): - pass - - -if __name__ == "__main__": - unittest.main() From 8085a95a63790b2128efc91a32d6ab2580a01839 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 13:55:38 +0000 Subject: [PATCH 136/139] fix docs --- docs/source/en/model_doc/blt.md | 10 ---------- utils/check_docstrings.py | 3 ++- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/docs/source/en/model_doc/blt.md b/docs/source/en/model_doc/blt.md index ebc54039d181..0289f77ac901 100644 --- a/docs/source/en/model_doc/blt.md +++ b/docs/source/en/model_doc/blt.md @@ -88,16 +88,6 @@ The original code can be found [here]() [[autodoc]] BltConfig -## BltTokenizer - -[[autodoc]] BltTokenizer - - build_inputs_with_special_tokens - - get_special_tokens_mask - - create_token_type_ids_from_sequences - - save_vocabulary - -## BltModel - [[autodoc]] BltModel - forward diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 81cfe7fd7e27..be52c5298472 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -130,7 +130,6 @@ "BloomTokenizerFast", "BLTConfig", "BLTPatcherConfig", - "BLTTokenizer", "BridgeTowerTextConfig", "BridgeTowerVisionConfig", "BrosModel", @@ -463,6 +462,8 @@ "ZeroShotImageClassificationPipeline", "ZeroShotObjectDetectionPipeline", "Llama4TextConfig", + "BltConfig", + "BltPatcherConfig", } # In addition to the objects above, we also ignore objects with certain prefixes. If you add an item to the list # below, make sure to add a comment explaining why. From 4272552d4fba895c23237b1ea9b943226aa77c74 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 14:21:19 +0000 Subject: [PATCH 137/139] ruff --- src/transformers/models/blt/modeling_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 443f8bd09efe..e1639d4e3e2b 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -159,7 +159,7 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Cache`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): From 05a5b4967bd604312a388cb80086329b89fedc44 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 14:30:38 +0000 Subject: [PATCH 138/139] fix after rebase --- src/transformers/models/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e6ef726f7565..f0939b089977 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -211,8 +211,8 @@ from .megatron_gpt2 import * from .mgp_str import * from .mimi import * - from .ministral import * from .minimax import * + from .ministral import * from .mistral import * from .mistral3 import * from .mixtral import * @@ -280,11 +280,11 @@ from .qwen2_audio import * from .qwen2_moe import * from .qwen2_vl import * + from .qwen3 import * + from .qwen3_moe import * from .qwen3_next import * from .qwen3_vl import * from .qwen3_vl_moe import * - from .qwen3 import * - from .qwen3_moe import * from .rag import * from .recurrent_gemma import * from .reformer import * From d983e72bb0de65b0f9177469a82c5ac8a7548295 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 18 Sep 2025 15:25:31 +0000 Subject: [PATCH 139/139] fix test, configs are not subscriptable --- tests/causal_lm_tester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index 8600f1dc265e..4757d4b69c6c 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -497,7 +497,7 @@ def _config_supports_rope_scaling(config: PretrainedConfig) -> bool: # Has rope_theta (and no rope_scaling) -> probably an older model, but should support rope scaling as well main_config_has_rope = hasattr(config, "rope_scaling") or hasattr(config, "rope_theta") sub_config_has_rope = any( - hasattr(config[sub_config], "rope_scaling") or hasattr(config[sub_config], "rope_theta") + hasattr(getattr(config, sub_config), "rope_scaling") or hasattr(getattr(config, sub_config), "rope_theta") for sub_config in config.sub_configs.keys() ) return main_config_has_rope or sub_config_has_rope