diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index c0258da704be..0441c9183bd7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1017,6 +1017,8 @@
title: CLIPSeg
- local: model_doc/clvp
title: CLVP
+ - local: model_doc/cwm
+ title: Code World Model (CWM)
- local: model_doc/cohere2_vision
title: Cohere2Vision
- local: model_doc/colpali
diff --git a/docs/source/en/model_doc/cwm.md b/docs/source/en/model_doc/cwm.md
new file mode 100644
index 000000000000..c789d1abdab1
--- /dev/null
+++ b/docs/source/en/model_doc/cwm.md
@@ -0,0 +1,186 @@
+<!-- Copyright 2025 the HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
+
+-->
+
+
+# Code World Model (CWM)
+
+## Overview
+
+The Code World Model (CWM) was proposed in [CWM: An Open-Weights LLM for Research on Code
+Generation with World Models](https://ai.facebook.com/research/publications/cwm) by the Meta FAIR CodeGen Team.
+CWM is an LLM for code generation and reasoning about code that has, in particular, been trained
+to better represent and reason about how code and commands affect the state of a program or system.
+To that end, CWM was mid-trained on a large number of observation-action trajectories from Python
+execution traces and agentic interactions in containerized environments, and then post-trained with
+extensive multi-task RL in verifiable coding, math, and multi-turn software engineering environments.
+
+The abstract from the paper is the following:
+
+> *We release Code World Model (CWM), a 32-billion-parameter open-weights LLM, to advance research
+on code generation with world models. To improve code understanding beyond what can be learned
+from training on static code alone, we mid-train CWM on a large amount of observation-action
+trajectories from Python interpreter and agentic Docker environments, and perform extensive
+multi-task reasoning RL in verifiable coding, math, and multi-turn software engineering environments. With
+CWM, we provide a strong testbed for researchers to explore the opportunities world modeling affords
+for improving code generation with reasoning and planning in computational environments. We
+present first steps of how world models can benefit agentic coding, enable step-by-step simulation of
+Python code execution, and show early results of how reasoning can benefit from the latter. CWM is
+a dense, decoder-only LLM trained with a context size of up to 131k tokens. Independent of its world
+modeling capabilities, CWM offers strong performance on general coding and math tasks: it reaches
+pass@1 scores of 65.8% on SWE-bench Verified (with test-time scaling), 68.9% on LiveCodeBench,
+96.6% on Math-500, and 76.0% on AIME 2024. To support further research on code world modeling,
+we release model checkpoints after mid-training, SFT, and RL.*
+
+Tips:
+
+CWM requires a dedicated system prompt to function optimally during inference. Without proper prompt
+configuration, CWM's output quality may be significantly degraded. The following serves as the default
+system prompt for reasoning tasks. For agentic workflows, append the relevant tool specifications
+after this base prompt. Check out the original code repository for more details.
+```
+You are a helpful AI assistant. You always reason before responding, using the following format:
+
+
+your internal reasoning
+
+your external response
+```
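+
+For agentic workflows, a minimal sketch of extending this prompt (here `base_system_prompt` stands for the reasoning prompt above, and `tool_spec` is a hypothetical placeholder for the tool specification text described in the original repository):
+
+```python
+# Hypothetical sketch: append the tool specification after the base reasoning prompt.
+base_system_prompt = "You are a helpful AI assistant. ..."  # the default prompt shown above
+tool_spec = "<tool specification from the CWM repository>"  # placeholder, see the original repo
+system_prompt = base_system_prompt + "\n\n" + tool_spec
+```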
+
+This model was contributed by the [Meta FAIR CodeGen Team](https://huggingface.co/facebook).
+The original code can be found [here](https://github.com/facebookresearch/cwm).
+
+## Usage examples
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_name = 'facebook/cwm'
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ dtype="auto",
+ device_map="auto"
+)
+
+system_prompt = """
+You are a helpful AI assistant. You always reason before responding, using the following format:
+
+
+your internal reasoning
+
+your external response
+""".strip()
+
+messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": "Write a haiku about recursion in programming."}
+]
+
+
+text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True, # Switches between thinking and non-thinking modes. Default is True.
+ preserve_previous_think=True, # Switches between keeping thinking blocks from previous messages or not. Default is True.
+)
+
+model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+generated_ids = model.generate(
+ **model_inputs,
+ max_new_tokens=1024
+)
+output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+print(tokenizer.decode(output_ids))
+```
+
+Produces the following output:
+
+ Okay, I need to write a haiku about recursion in programming. Let's start by recalling what recursion is. Recursion is a programming technique where a function calls itself in its own definition, often used to solve problems that can be broken down into smaller, similar sub-problems.
+
+ A haiku is a three-line poem with a syllable count of 5-7-5. So, I need to capture the essence of recursion in a concise and poetic way. Let's think about the key elements of recursion: the function calling itself, the base case that stops the recursion, and the idea of breaking down problems into smaller parts.
+
+ Maybe I can use imagery related to mirrors or reflections, since recursion often involves a function reflecting or repeating itself. Alternatively, I could use natural imagery like a spiral or a tree, which suggests self-similar patterns.
+
+ Let me try to come up with some ideas. For example, "Function calls itself" could be a starting point. But how to make that into a haiku? Maybe something like:
+
+ Function calls itself
+ Base case stops the endless loop
+ Spiral of code
+
+ Hmm, that's a start. But maybe I can make it more poetic. Let's think about the idea of a function calling itself as a kind of infinite loop, but with a base case that stops it. That's like a spiral that unwinds at the center.
+
+ Another angle: recursion is often used for problems that have a recursive structure, like tree traversals or factorial calculations. Maybe I can use a natural image to represent that.
+
+ How about:
+
+ Branches split again
+ Leaves unfold in endless dance
+ Roots hold the base case
+
+ That's a bit more poetic. The branches splitting represent the recursive calls, the leaves unfolding could represent the processing of each sub-problem, and the roots holding the base case are what stops the recursion.
+
+ Alternatively, maybe something simpler:
+
+ Mirror within mirror
+ Reflections infinite deep
+ Base case breaks the glass
+
+ This uses the mirror metaphor to represent the function calling itself, with the base case being the point where the recursion stops, breaking the infinite reflection.
+
+ I think I like the mirror idea better. Let me refine that:
+
+ Mirror within mirror
+ Infinite reflections dance
+ Base case shatters glass
+
+ Yes, that seems to capture the idea of recursion well. The mirror represents the function calling itself, the infinite reflections are the recursive calls, and the base case shattering the glass is what stops the recursion.
+
+ Let me check the syllable count:
+
+ Mirror within mirror (5 syllables: Mir-ror-with-in-mir-ror)
+ Infinite reflections dance (7 syllables: In-fi-nite-ref-lec-tions-dance)
+ Base case shatters glass (5 syllables: Base-case-shat-ters-glass)
+
+ Yes, that fits the 5-7-5 structure. I think that's a good haiku about recursion.
+
+ Mirror within mirror
+ Infinite reflections dance
+ Base case shatters glass<|eot_id|>
+
+
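+To skip the reasoning block, the chat template accepts `enable_thinking=False`. Below is a minimal sketch reusing the `tokenizer`, `model`, and `messages` objects from the example above (the exact effect depends on the checkpoint's chat template):
+
+```python
+# Render the prompt without a thinking block and generate as before.
+text = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True,
+    enable_thinking=False,
+)
+model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+generated_ids = model.generate(**model_inputs, max_new_tokens=256)
+print(tokenizer.decode(generated_ids[0][model_inputs.input_ids.shape[1]:]))
+```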
+
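+By default, CWM interleaves sliding-window and full attention: every 4th decoder layer uses full attention and the remaining layers use an 8192-token sliding window. A quick sketch of inspecting this from the configuration:
+
+```python
+from transformers import CwmConfig
+
+config = CwmConfig()
+print(config.sliding_window)   # 8192
+print(config.layer_types[:8])
+# ['full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention',
+#  'full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention']
+```
+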
+## CwmConfig
+
+[[autodoc]] CwmConfig
+
+## CwmPreTrainedModel
+
+[[autodoc]] CwmPreTrainedModel
+
+## CwmModel
+
+[[autodoc]] CwmModel
+ - forward
+
+## CwmForCausalLM
+
+[[autodoc]] CwmForCausalLM
+    - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 2905a842612e..643d5d9c8bdb 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -76,6 +76,7 @@
from .csm import *
from .ctrl import *
from .cvt import *
+ from .cwm import *
from .d_fine import *
from .dab_detr import *
from .dac import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index c40b5a37b02a..cc6252bf1769 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -94,6 +94,7 @@
("csm", "CsmConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
+ ("cwm", "CwmConfig"),
("d_fine", "DFineConfig"),
("dab-detr", "DabDetrConfig"),
("dac", "DacConfig"),
@@ -526,6 +527,7 @@
("csm", "CSM"),
("ctrl", "CTRL"),
("cvt", "CvT"),
+ ("cwm", "Code World Model (CWM)"),
("d_fine", "D-FINE"),
("dab-detr", "DAB-DETR"),
("dac", "DAC"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 297d4890d131..8ea63e08dc66 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -99,6 +99,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("csm", "CsmForConditionalGeneration"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
+ ("cwm", "CwmModel"),
("d_fine", "DFineModel"),
("dab-detr", "DabDetrModel"),
("dac", "DacModel"),
@@ -644,6 +645,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("cohere2", "Cohere2ForCausalLM"),
("cpmant", "CpmAntForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
+ ("cwm", "CwmForCausalLM"),
("data2vec-text", "Data2VecTextForCausalLM"),
("dbrx", "DbrxForCausalLM"),
("deepseek_v2", "DeepseekV2ForCausalLM"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 5aecc119560e..beddc007492e 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -170,6 +170,13 @@
("cpmant", ("CpmAntTokenizer", None)),
("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("ctrl", ("CTRLTokenizer", None)),
+ (
+ "cwm",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/cwm/__init__.py b/src/transformers/models/cwm/__init__.py
new file mode 100644
index 000000000000..93c906ce6db9
--- /dev/null
+++ b/src/transformers/models/cwm/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_cwm import *
+ from .modeling_cwm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/cwm/configuration_cwm.py b/src/transformers/models/cwm/configuration_cwm.py
new file mode 100644
index 000000000000..8dd2887722f7
--- /dev/null
+++ b/src/transformers/models/cwm/configuration_cwm.py
@@ -0,0 +1,196 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/cwm/modular_cwm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_cwm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class CwmConfig(PretrainedConfig):
+ """
+    Configuration for the Code World Model (CWM).
+    This is a Llama 3-compatible configuration with layer-interleaved sliding-window attention, used to
+    configure a [`CwmModel`]. The defaults yield a configuration mirroring the
+    [facebook/cwm](https://huggingface.co/facebook/cwm) architecture. Other checkpoints include:
+ - [facebook/cwm-sft](https://huggingface.co/facebook/cwm-sft)
+ - [facebook/cwm-pretrain](https://huggingface.co/facebook/cwm-pretrain)
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128256):
+ Vocabulary size of the CWM model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`CwmModel`]
+ hidden_size (`int`, *optional*, defaults to 6144):
+ Dimension of the hidden representations
+ intermediate_size (`int`, *optional*, defaults to 21504):
+ Dimension of the MLP representations
+ num_hidden_layers (`int`, *optional*, defaults to 64):
+ Number of hidden layers in the Transformer decoder
+ num_attention_heads (`int`, *optional*, defaults to 48):
+ Number of attention heads for each attention layer in the Transformer decoder
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention (GQA).
+ If it is not specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with. CWM's attention allows sequence
+ lengths up to 131072 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ eos_token_id (`int` or `list[int]`, *optional*, defaults to `[128001, 128008, 128009]`):
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ The id of the *beginning-of-sequence* token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Tensor parallelism degree used during pretraining. See [this
+ document](https://huggingface.co/docs/transformers/parallelism) and [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. If not provided, defaults to the
+            Llama 3-style scaling used by CWM (`rope_type="llama3"`, `factor=16.0`, `high_freq_factor=4.0`,
+            `low_freq_factor=1.0`, `original_max_position_embeddings=8192`).
+ sliding_window (`int`, *optional*, defaults to 8192):
+ Sliding window attention window size.
+ layer_types (`List[str]`, *optional*):
+ List of layer types for each layer. Each element should be either "full_attention" or "sliding_attention".
+            If not specified, defaults to a pattern where every 4th layer uses full attention and the remaining
+            layers use sliding-window attention.
+ """
+
+ model_type = "cwm"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `CwmModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: int = 128256,
+ hidden_size: int = 6144,
+ intermediate_size: int = 21504,
+ num_hidden_layers: int = 64,
+ num_attention_heads: int = 48,
+ num_key_value_heads: int = 8,
+ head_dim: int = 128,
+ hidden_act: str = "silu",
+ max_position_embeddings: int = 131072,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-5,
+ use_cache: bool = True,
+ pad_token_id: Optional[int] = None,
+ eos_token_id=[128001, 128008, 128009],
+ bos_token_id: int = 128000,
+ tie_word_embeddings: bool = False,
+ rope_theta: float = 1_000_000.0,
+ attention_dropout: float = 0.0,
+ pretraining_tp: int = 1,
+ mlp_bias: bool = False,
+ rope_scaling: Optional[dict] = None,
+ # CWM interleaved sliding window fields
+ sliding_window: int = 8192,
+ layer_types: Optional[list[str]] = None, # ["full_attention"|"sliding_attention"] per layer
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ if rope_scaling is None:
+ rope_scaling = {
+ "factor": 16.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3",
+ }
+
+ if layer_types is None:
+ # Default pattern: every 4th layer uses full attention, others use sliding attention
+ window_pattern = 4
+ layer_types = [
+ ("full_attention" if (i % window_pattern == 0) else "sliding_attention")
+ for i in range(num_hidden_layers)
+ ]
+ else:
+ layer_type_validation(layer_types, num_hidden_layers)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ self.sliding_window = int(sliding_window) if sliding_window else None
+ self.layer_types = list(layer_types)
+
+
+__all__ = ["CwmConfig"]
diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py
new file mode 100644
index 000000000000..1f0c09996086
--- /dev/null
+++ b/src/transformers/models/cwm/modeling_cwm.py
@@ -0,0 +1,486 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/cwm/modular_cwm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_cwm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_cwm import CwmConfig
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class CwmAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: CwmConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = torch.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window, # main diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class CwmRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ CwmRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class CwmMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class CwmDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: CwmConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = CwmMLP(config)
+ self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class CwmPreTrainedModel(PreTrainedModel):
+ config: CwmConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CwmDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": CwmDecoderLayer,
+ "attentions": CwmAttention,
+ }
+
+
+class CwmModelOutputWithPast(BaseModelOutputWithPast):
+ pass
+
+
+class CwmRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: CwmConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class CwmModel(CwmPreTrainedModel):
+ config_class = CwmConfig
+
+ def __init__(self, config: CwmConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = torch.nn.ModuleList(
+ [CwmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = CwmRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CwmModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ sliding_mask_kwargs = mask_kwargs.copy()
+
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return CwmModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class CwmForCausalLM(CwmPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = CwmModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CwmForCausalLM
+
+        >>> model = CwmForCausalLM.from_pretrained("facebook/cwm")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["CwmPreTrainedModel", "CwmModel", "CwmForCausalLM"]
diff --git a/src/transformers/models/cwm/modular_cwm.py b/src/transformers/models/cwm/modular_cwm.py
new file mode 100644
index 000000000000..022fe9c21c19
--- /dev/null
+++ b/src/transformers/models/cwm/modular_cwm.py
@@ -0,0 +1,293 @@
+# coding=utf-8
+# Copyright 2025
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ..llama.configuration_llama import LlamaConfig
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaPreTrainedModel,
+)
+from ..qwen2.modeling_qwen2 import Qwen2Attention
+
+
+logger = logging.get_logger(__name__)
+
+
+class CwmConfig(LlamaConfig):
+ """
+    Configuration for the Code World Model (CWM).
+    This is a Llama 3-compatible configuration with layer-interleaved sliding-window attention, used to
+    configure a [`CwmModel`]. The defaults yield a configuration mirroring the
+    [facebook/cwm](https://huggingface.co/facebook/cwm) architecture. Other checkpoints include:
+ - [facebook/cwm-sft](https://huggingface.co/facebook/cwm-sft)
+ - [facebook/cwm-pretrain](https://huggingface.co/facebook/cwm-pretrain)
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128256):
+ Vocabulary size of the CWM model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`CwmModel`]
+ hidden_size (`int`, *optional*, defaults to 6144):
+ Dimension of the hidden representations
+ intermediate_size (`int`, *optional*, defaults to 21504):
+ Dimension of the MLP representations
+ num_hidden_layers (`int`, *optional*, defaults to 64):
+ Number of hidden layers in the Transformer decoder
+ num_attention_heads (`int`, *optional*, defaults to 48):
+ Number of attention heads for each attention layer in the Transformer decoder
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention (GQA).
+ If it is not specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with. CWM's attention allows sequence
+ lengths up to 131072 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ eos_token_id (`int` or `list[int]`, *optional*, defaults to `[128001, 128008, 128009]`):
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ The id of the *beginning-of-sequence* token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Tensor parallelism degree used during pretraining. See [this
+ document](https://huggingface.co/docs/transformers/parallelism) and [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. If not provided, defaults to the
+            Llama 3-style scaling used by CWM (`rope_type="llama3"`, `factor=16.0`, `high_freq_factor=4.0`,
+            `low_freq_factor=1.0`, `original_max_position_embeddings=8192`).
+ sliding_window (`int`, *optional*, defaults to 8192):
+ Sliding window attention window size.
+ layer_types (`List[str]`, *optional*):
+ List of layer types for each layer. Each element should be either "full_attention" or "sliding_attention".
+            If not specified, defaults to a pattern where every 4th layer uses full attention and the remaining
+            layers use sliding-window attention.
+ """
+
+ model_type = "cwm"
+
+ def __init__(
+ self,
+ vocab_size: int = 128256,
+ hidden_size: int = 6144,
+ intermediate_size: int = 21504,
+ num_hidden_layers: int = 64,
+ num_attention_heads: int = 48,
+ num_key_value_heads: int = 8,
+ head_dim: int = 128,
+ hidden_act: str = "silu",
+ max_position_embeddings: int = 131072,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-5,
+ use_cache: bool = True,
+ pad_token_id: Optional[int] = None,
+ eos_token_id=[128001, 128008, 128009],
+ bos_token_id: int = 128000,
+ tie_word_embeddings: bool = False,
+ rope_theta: float = 1_000_000.0,
+ attention_dropout: float = 0.0,
+ pretraining_tp: int = 1,
+ mlp_bias: bool = False,
+ rope_scaling: Optional[dict] = None,
+ # CWM interleaved sliding window fields
+ sliding_window: int = 8192,
+ layer_types: Optional[list[str]] = None, # ["full_attention"|"sliding_attention"] per layer
+ **kwargs,
+ ):
+ if rope_scaling is None:
+ rope_scaling = {
+ "factor": 16.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3",
+ }
+
+ if layer_types is None:
+ # Default pattern: every 4th layer uses full attention, others use sliding attention
+ window_pattern = 4
+ layer_types = [
+ ("full_attention" if (i % window_pattern == 0) else "sliding_attention")
+ for i in range(num_hidden_layers)
+ ]
+ else:
+ layer_type_validation(layer_types, num_hidden_layers)
+
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ head_dim=head_dim,
+ hidden_act=hidden_act,
+ max_position_embeddings=max_position_embeddings,
+ initializer_range=initializer_range,
+ rms_norm_eps=rms_norm_eps,
+ use_cache=use_cache,
+ pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+ bos_token_id=bos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ rope_theta=rope_theta,
+ attention_bias=False,
+ attention_dropout=attention_dropout,
+ rope_scaling=rope_scaling,
+ pretraining_tp=pretraining_tp,
+ mlp_bias=mlp_bias,
+ **kwargs,
+ )
+
+ # CWM models don't use attention bias, remove it from config
+ del self.attention_bias
+
+ self.sliding_window = int(sliding_window) if sliding_window else None
+ self.layer_types = list(layer_types)
+
+
+class CwmAttention(Qwen2Attention):
+ def __init__(self, config: CwmConfig, layer_idx: int):
+ super().__init__(config=config, layer_idx=layer_idx)
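+        # CWM reuses Qwen2's attention for its sliding-window handling, but unlike Qwen2 it does not use
+        # attention bias, so the query/key/value projections are re-created without bias terms.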
+ self.q_proj = torch.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+
+
+class CwmDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: CwmConfig, layer_idx: int):
+ super().__init__(config=config, layer_idx=layer_idx)
+ self.attention_type = config.layer_types[layer_idx]
+ self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)
+
+
+class CwmPreTrainedModel(LlamaPreTrainedModel):
+ pass
+
+
+class CwmModelOutputWithPast(BaseModelOutputWithPast):
+ pass
+
+
+class CwmModel(LlamaModel):
+ config_class = CwmConfig
+
+ def __init__(self, config: CwmConfig):
+ super().__init__(config)
+ self.layers = torch.nn.ModuleList(
+ [CwmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CwmModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ sliding_mask_kwargs = mask_kwargs.copy()
+
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return CwmModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class CwmForCausalLM(LlamaForCausalLM):
+ pass
+
+
+__all__ = [
+ "CwmConfig",
+ "CwmPreTrainedModel",
+ "CwmModel",
+ "CwmForCausalLM",
+]
diff --git a/tests/models/cwm/__init__.py b/tests/models/cwm/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/cwm/test_configuration_cwm.py b/tests/models/cwm/test_configuration_cwm.py
new file mode 100644
index 000000000000..3d6f49527875
--- /dev/null
+++ b/tests/models/cwm/test_configuration_cwm.py
@@ -0,0 +1,126 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.models.cwm import CwmConfig
+from transformers.testing_utils import require_torch
+
+from ...test_configuration_common import ConfigTester
+
+
+class CwmConfigTest(unittest.TestCase):
+ def test_default_config(self):
+ """Test default CWM configuration"""
+ config = CwmConfig()
+
+ # CWM defaults
+ self.assertEqual(config.sliding_window, 8192)
+ self.assertIsInstance(config.layer_types, list)
+
+ # Llama3 defaults
+ self.assertEqual(config.vocab_size, 128256)
+ self.assertEqual(config.rope_theta, 1_000_000.0)
+ self.assertIsNotNone(config.rope_scaling)
+ self.assertEqual(config.rope_scaling["rope_type"], "llama3")
+
+ def test_custom_sliding_window_config(self):
+ config = CwmConfig(sliding_window=4096)
+
+ self.assertEqual(config.sliding_window, 4096)
+
+ def test_custom_layer_types_config(self):
+ layer_types = ["full_attention", "sliding_attention", "sliding_attention", "full_attention"]
+ config = CwmConfig(num_hidden_layers=4, layer_types=layer_types)
+
+ self.assertEqual(config.layer_types, layer_types)
+ self.assertEqual(len(config.layer_types), config.num_hidden_layers)
+
+ def test_invalid_layer_types_length(self):
+ with self.assertRaises(ValueError):
+ CwmConfig(
+ num_hidden_layers=4,
+ layer_types=["full_attention", "sliding_attention"], # Only 2 types for 4 layers
+ )
+
+ def test_invalid_layer_type_value(self):
+ with self.assertRaises(ValueError):
+ CwmConfig(num_hidden_layers=2, layer_types=["full_attention", "invalid_attention"])
+
+ def test_automatic_layer_types_generation(self):
+ # Test default pattern (every 4th layer uses full attention)
+ config = CwmConfig(num_hidden_layers=8)
+
+ expected_types = [
+ "full_attention", # layer 0: 0 % 4 == 0
+ "sliding_attention", # layer 1: 1 % 4 != 0
+ "sliding_attention", # layer 2: 2 % 4 != 0
+ "sliding_attention", # layer 3: 3 % 4 != 0
+ "full_attention", # layer 4: 4 % 4 == 0
+ "sliding_attention", # layer 5: 5 % 4 != 0
+ "sliding_attention", # layer 6: 6 % 4 != 0
+ "sliding_attention", # layer 7: 7 % 4 != 0
+ ]
+
+ self.assertEqual(config.layer_types, expected_types)
+
+ def test_rope_scaling_config(self):
+ custom_rope_scaling = {
+ "factor": 8.0,
+ "high_freq_factor": 2.0,
+ "low_freq_factor": 0.5,
+ "original_max_position_embeddings": 4096,
+ "rope_type": "llama3",
+ }
+
+ config = CwmConfig(rope_scaling=custom_rope_scaling)
+
+ self.assertEqual(config.rope_scaling, custom_rope_scaling)
+
+ def test_config_serialization(self):
+ config = CwmConfig(
+ sliding_window=4096,
+ layer_types=["full_attention", "sliding_attention"] * 3,
+ num_hidden_layers=6,
+ )
+
+ config_dict = config.to_dict()
+ self.assertIn("sliding_window", config_dict)
+ self.assertIn("layer_types", config_dict)
+
+ new_config = CwmConfig.from_dict(config_dict)
+ self.assertEqual(new_config.sliding_window, config.sliding_window)
+ self.assertEqual(new_config.layer_types, config.layer_types)
+
+ def test_config_inheritance_from_llama(self):
+ config = CwmConfig()
+
+ # Llama config attributes
+ self.assertTrue(hasattr(config, "hidden_size"))
+ self.assertTrue(hasattr(config, "num_attention_heads"))
+ self.assertTrue(hasattr(config, "num_key_value_heads"))
+ self.assertTrue(hasattr(config, "intermediate_size"))
+ self.assertTrue(hasattr(config, "rope_theta"))
+ self.assertTrue(hasattr(config, "attention_dropout"))
+
+
+@require_torch
+class CwmConfigCommonTest(unittest.TestCase):
+    def test_config(self):
+        # ConfigTester needs a unittest.TestCase parent to provide the assertion helpers it calls.
+        config_tester = ConfigTester(self, config_class=CwmConfig)
+        config_tester.run_common_tests()
diff --git a/tests/models/cwm/test_modeling_cwm.py b/tests/models/cwm/test_modeling_cwm.py
new file mode 100644
index 000000000000..ff1a90323903
--- /dev/null
+++ b/tests/models/cwm/test_modeling_cwm.py
@@ -0,0 +1,298 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import (
+ cleanup,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+if is_torch_available():
+ import torch
+
+ from transformers.models.cwm import (
+ CwmConfig,
+ CwmForCausalLM,
+ CwmModel,
+ )
+
+
+class CwmModelTester(CausalLMModelTester):
+ if is_torch_available():
+ config_class = CwmConfig
+ base_model_class = CwmModel
+ causal_lm_class = CwmForCausalLM
+
+ def get_config(self):
+ config = super().get_config()
+
+ config.sliding_window = 8192
+ config.rope_scaling = {
+ "factor": 16.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3",
+ }
+
+ return config
+
+
+@require_torch
+class CwmModelTest(CausalLMModelTest, unittest.TestCase):
+ all_model_classes = (
+ (
+ CwmModel,
+ CwmForCausalLM,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": CwmModel,
+ "text-generation": CwmForCausalLM,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+ fx_compatible = False
+ model_tester_class = CwmModelTester
+
+ model_split_percents = [0.5, 0.7, 0.8]
+
+ _torch_compile_train_cls = CwmForCausalLM if is_torch_available() else None
+
+
+@require_torch_accelerator
+@slow
+class CwmIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @slow
+ def test_cwm_integration(self):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
+ model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
+
+ self.assertIsNotNone(model.config.sliding_window)
+ self.assertIsNotNone(model.config.layer_types)
+ self.assertIn("full_attention", model.config.layer_types)
+ self.assertIn("sliding_attention", model.config.layer_types)
+
+ for i, layer in enumerate(model.model.layers):
+ expected_type = model.config.layer_types[i]
+ self.assertEqual(layer.attention_type, expected_type)
+ if expected_type == "sliding_attention":
+ self.assertEqual(layer.self_attn.sliding_window, model.config.sliding_window)
+
+ prompt = "def quicksort(arr):"
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+ out = model(**inputs)
+
+ expected_logits = torch.tensor(
+ [
+ 0.5625,
+ 2.9531,
+ 9.1875,
+ 0.4746,
+ -0.3613,
+ 2.2031,
+ 2.9844,
+ 1.5312,
+ 0.5859,
+ 1.5391,
+ 2.7500,
+ 3.4375,
+ 2.0156,
+ 2.1719,
+ 1.5469,
+ 2.5469,
+ 2.8438,
+ 1.8203,
+ 1.7188,
+ 1.3984,
+ 1.0469,
+ 0.1748,
+ 0.4453,
+ 0.1533,
+ -0.1157,
+ 0.8516,
+ 2.2344,
+ 5.2188,
+ 1.2891,
+ 1.5234,
+ 0.8555,
+ 0.6992,
+ ],
+ dtype=torch.bfloat16,
+ ).to(model.device)
+
+ self.assertTrue(torch.allclose(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2))
+
+ self.assertEqual(out.logits.shape[1], inputs.input_ids.shape[1])
+ self.assertEqual(out.logits.shape[2], model.config.vocab_size)
+ self.assertFalse(torch.isnan(out.logits).any())
+ self.assertFalse(torch.isinf(out.logits).any())
+
+ @slow
+ def test_cwm_sliding_window_long_sequence(self):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
+ model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
+
+ sliding_window = model.config.sliding_window
+ long_text = "for i in range(1000):\n print(f'iteration {i}')\n" * 600
+
+ inputs = tokenizer(long_text, return_tensors="pt").to(model.device)
+ seq_len = inputs.input_ids.shape[1]
+
+ # create a sequence longer than sliding window
+ self.assertGreater(
+ seq_len, sliding_window, f"Test sequence length {seq_len} should be > sliding window {sliding_window}"
+ )
+
+ with torch.no_grad():
+ out = model(**inputs)
+
+ expected_logits = torch.tensor(
+ [
+ 4.7812,
+ 6.1875,
+ 13.1875,
+ 4.4062,
+ 5.0312,
+ 3.9844,
+ 6.6875,
+ 4.8438,
+ 2.3125,
+ 6.5000,
+ 4.4688,
+ 0.5195,
+ 5.6562,
+ 3.3125,
+ 2.7500,
+ 4.9062,
+ 5.5938,
+ 4.1562,
+ 3.9531,
+ 2.4062,
+ 3.2812,
+ 2.8594,
+ 3.4688,
+ 2.9688,
+ 2.6875,
+ 3.4531,
+ 2.7344,
+ 7.2812,
+ 4.5000,
+ 5.7500,
+ 2.3438,
+ 5.9688,
+ ],
+ dtype=torch.bfloat16,
+ ).to(model.device)
+
+ self.assertTrue(torch.allclose(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2))
+
+ self.assertEqual(out.logits.shape[1], seq_len)
+ self.assertEqual(out.logits.shape[2], model.config.vocab_size)
+ self.assertFalse(torch.isnan(out.logits).any())
+ self.assertFalse(torch.isinf(out.logits).any())
+
+ for i, layer in enumerate(model.model.layers):
+ if model.config.layer_types[i] == "sliding_attention":
+ self.assertEqual(layer.self_attn.sliding_window, sliding_window)
+
+ @slow
+ def test_cwm_generation_20_tokens(self):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
+ model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
+
+ system_prompt = "You are a helpful AI assistant. You always reason before responding, using the following format:\n\n\nyour internal reasoning\n\nyour external response"
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": "Write a simple Python function to add two numbers."},
+ ]
+
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ preserve_previous_think=True,
+ )
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+ generated_ids = model.generate(
+ **model_inputs,
+ max_new_tokens=20,
+ do_sample=False,
+ temperature=1.0,
+ top_p=1.0,
+ pad_token_id=tokenizer.eos_token_id,
+ )
+
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+ generated_text = tokenizer.decode(output_ids, skip_special_tokens=False)
+
+ self.assertEqual(len(output_ids), 20, "Should generate exactly 20 tokens")
+
+ expected_token_ids = [
+ 33413,
+ 11,
+ 358,
+ 1205,
+ 311,
+ 3350,
+ 264,
+ 13325,
+ 734,
+ 430,
+ 11621,
+ 1403,
+ 5219,
+ 13,
+ 6914,
+ 596,
+ 1212,
+ 555,
+ 89746,
+ 1268,
+ ]
+ expected_text = "Okay, I need to write a Python function that adds two numbers. Let's start by recalling how"
+
+ self.assertEqual(output_ids, expected_token_ids, "Generated tokens should match ground truth")
+ self.assertEqual(generated_text, expected_text, "Generated text should match ground truth")