Merged
79 commits
1e92e6d
initial config and MLA layer
Ssukriti Apr 1, 2025
c4f8051
first pass at decoder
Ssukriti Apr 2, 2025
4966ac1
completion of layers
Ssukriti Apr 3, 2025
721645a
modeling class
Ssukriti Apr 3, 2025
570147a
adding hybrid class to imports
Ssukriti Apr 3, 2025
3fcb2bf
fix imports granitemoehybrid
Ssukriti Apr 4, 2025
1d5b29e
fix granitehybrid imports
Ssukriti Apr 4, 2025
ff2f4e0
fix granitehybrid import
Ssukriti Apr 4, 2025
a69ef3d
fix generated modeling file
Ssukriti Apr 4, 2025
846a507
add some comments
Ssukriti Apr 4, 2025
69c061e
minor fixes in layers
Ssukriti Apr 6, 2025
d5e310c
add sharedMLP layer
Ssukriti Apr 6, 2025
e7bad48
correct layer names
Ssukriti Apr 6, 2025
5d5a87a
fixes in mamba config
Ssukriti Apr 7, 2025
711fc62
fix mamba config
Ssukriti Apr 7, 2025
8177217
change name of MLP layer
Ssukriti Apr 7, 2025
9790fbe
fix seq mizer layers
Ssukriti Apr 7, 2025
3151198
correct mamba config
Ssukriti Apr 7, 2025
e9c145a
fixes in param names
Ssukriti Apr 7, 2025
c32b8b0
enable hybrid model
Ssukriti Apr 10, 2025
018decd
update config
Ssukriti Apr 10, 2025
808e8b2
fix config granite hybrid
Ssukriti Apr 10, 2025
2bdfc51
fix attention layer
Ssukriti Apr 10, 2025
0969237
cleanup to re-use mamba code
Ssukriti Apr 11, 2025
278ed95
keep layer types
Ssukriti Apr 11, 2025
57f08c3
attention bias cleanup
Ssukriti Apr 14, 2025
92e9e66
update mamba layer name
Ssukriti Apr 15, 2025
601b1a3
first pass at tests
Ssukriti Apr 15, 2025
6a7d73d
first pass at tests
Ssukriti Apr 16, 2025
51fa4a5
use granite attention
Ssukriti Apr 17, 2025
5464c3c
fix: self attn weights
Ssukriti Apr 17, 2025
14fab28
pass at making pos_emb optional
Ssukriti Apr 18, 2025
d2476c5
initialize self_attn only as needed
Ssukriti Apr 18, 2025
c398665
overwrite forward to create HybridMambaCache
Ssukriti Apr 20, 2025
0a7f39b
Log invalid layer types
alex-jw-brooks Apr 29, 2025
67dbdfc
Add attention outputs test
alex-jw-brooks Apr 29, 2025
16fe690
Only emit attentions/logits if not None
alex-jw-brooks Apr 29, 2025
5c9a2d2
Fix config test hidden size divisibility
alex-jw-brooks Apr 29, 2025
1d00c3b
mark granitmoehybrid as stateful
alex-jw-brooks Apr 29, 2025
c649791
Initialize mamba convolutional layers
alex-jw-brooks Apr 29, 2025
caa0c93
Formatting fixes
alex-jw-brooks Apr 29, 2025
f3c04ea
config docstring, removed some unused attrs
alex-jw-brooks Apr 29, 2025
633536e
Fix missing arg in models test
alex-jw-brooks Apr 29, 2025
da56418
Fix create and check decoder model test
alex-jw-brooks Apr 29, 2025
352b469
support logits to keep in granitemoe
alex-jw-brooks Apr 29, 2025
63c47fd
regen to pass logits_to_keep
alex-jw-brooks Apr 29, 2025
9ed519e
Allow None or rope
alex-jw-brooks Apr 29, 2025
b2b7ffc
Fix gradient checkpointing
alex-jw-brooks Apr 29, 2025
473f2c0
Add granitemoehybrid as special cache for generate check
alex-jw-brooks Apr 29, 2025
7e0c3cf
Remove unused MLA refs
alex-jw-brooks Apr 30, 2025
d7a1ce8
Fix mamba layer mask
alex-jw-brooks Apr 30, 2025
809b596
Remove logits to keep from config
alex-jw-brooks Apr 30, 2025
b178033
Minor docstring nits
alex-jw-brooks Apr 30, 2025
5918634
Update licenses
alex-jw-brooks Apr 30, 2025
2e0ce02
Enable cache by default
alex-jw-brooks Apr 30, 2025
d8c1b4b
map layer types to layer block type
alex-jw-brooks Apr 30, 2025
37ed6e6
First pass at granite moe hybrid docs
alex-jw-brooks Apr 30, 2025
a9c01f1
Ignore granite moe hybrid in valid checkpoint check
alex-jw-brooks Apr 30, 2025
8f2171a
Align attention interfaces
alex-jw-brooks Apr 30, 2025
39d8cc6
regenerate modular granitemoeshared attention interface
alex-jw-brooks Apr 30, 2025
92ab83a
Align granite moe hybrid attn interface
alex-jw-brooks Apr 30, 2025
a684f96
run formatting
alex-jw-brooks Apr 30, 2025
94d7a9c
Handle mamba initialization
alex-jw-brooks Apr 30, 2025
8bd0b2d
avoid conditional attr defs
alex-jw-brooks Apr 30, 2025
d59ced5
Move hybrid layer validation to config
alex-jw-brooks Apr 30, 2025
610d93a
Add placeholder integration tests
alex-jw-brooks Apr 30, 2025
8274d2c
Docs nits / Update model names
alex-jw-brooks Apr 30, 2025
d43ffcf
Clean up forward conditions
alex-jw-brooks May 1, 2025
c6679ab
Use gradient checkpointing layer
alex-jw-brooks May 1, 2025
f4dca66
Remove some copied bamba tests + inherit
alex-jw-brooks May 1, 2025
1672a75
avoid redundant intermediate std var
alex-jw-brooks May 1, 2025
74ef853
use @can_return_tuple
alex-jw-brooks May 1, 2025
f12c856
Remove unused moe state
alex-jw-brooks May 1, 2025
156a682
make skipped test names consistent
alex-jw-brooks May 1, 2025
91f176e
Fix docstring order
alex-jw-brooks May 1, 2025
ff7dae2
Add missing toc
alex-jw-brooks May 2, 2025
d9cf0cc
Always create the shared mlp
alex-jw-brooks May 2, 2025
1c0272a
Fix name in docstring
alex-jw-brooks May 2, 2025
bd3081a
link preview model in docs
alex-jw-brooks May 5, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -495,6 +495,8 @@
title: Granite
- local: model_doc/granitemoe
title: GraniteMoe
- local: model_doc/granitemoehybrid
title: GraniteMoeHybrid
- local: model_doc/granitemoeshared
title: GraniteMoeShared
- local: model_doc/helium
64 changes: 64 additions & 0 deletions docs/source/en/model_doc/granitemoehybrid.md
@@ -0,0 +1,64 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GraniteMoeHybrid

## Overview


The `GraniteMoeHybrid` model builds on `GraniteMoeSharedModel` and `Bamba`. Each decoder layer uses either a state-space (Mamba) mixer or self-attention for sequence mixing, combined with a mixture-of-experts block that includes shared experts. By default, the attention layers do not use positional encoding.


```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()

# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."

# tokenize the text
input_tokens = tokenizer(prompt, return_tensors="pt")
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print, in this example the batch size is 1
for i in output:
    print(i)
```
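
The hybrid layer layout can also be inspected from the model configuration. Below is a minimal sketch, assuming the configuration exposes a per-layer type list (named `layer_types` here, following this PR's "keep layer types" commits) and a `position_embedding_type` attribute; both attribute names are assumptions rather than guaranteed API.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ibm-granite/granite-4.0-tiny-preview")

# `layer_types` is an assumed attribute name; guard with getattr so the
# sketch degrades gracefully if the released config names it differently
layer_types = getattr(config, "layer_types", None)
if layer_types is not None:
    n_mamba = sum(t == "mamba" for t in layer_types)
    n_attention = sum(t == "attention" for t in layer_types)
    print(f"mamba layers: {n_mamba}, attention layers: {n_attention}")

# the attention layers use no positional encoding by default; the exact
# attribute name and value are assumptions, not a guaranteed API
print(getattr(config, "position_embedding_type", "not exposed"))
```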

This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).


## GraniteMoeHybridConfig

[[autodoc]] GraniteMoeHybridConfig

## GraniteMoeHybridModel

[[autodoc]] GraniteMoeHybridModel
- forward

## GraniteMoeHybridForCausalLM

[[autodoc]] GraniteMoeHybridForCausalLM
- forward
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -129,6 +129,7 @@
from .granite import *
from .granite_speech import *
from .granitemoe import *
from .granitemoehybrid import *
from .granitemoeshared import *
from .grounding_dino import *
from .groupvit import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -146,6 +146,7 @@
("granite", "GraniteConfig"),
("granite_speech", "GraniteSpeechConfig"),
("granitemoe", "GraniteMoeConfig"),
("granitemoehybrid", "GraniteMoeHybridConfig"),
("granitemoeshared", "GraniteMoeSharedConfig"),
("granitevision", "LlavaNextConfig"),
("graphormer", "GraphormerConfig"),
@@ -509,6 +510,7 @@
("granite", "Granite"),
("granite_speech", "GraniteSpeech"),
("granitemoe", "GraniteMoeMoe"),
("granitemoehybrid", "GraniteMoeHybrid"),
("granitemoeshared", "GraniteMoeSharedMoe"),
("granitevision", "LLaVA-NeXT"),
("graphormer", "Graphormer"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -138,6 +138,7 @@
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("granite", "GraniteModel"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
("granitemoeshared", "GraniteMoeSharedModel"),
("graphormer", "GraphormerModel"),
("grounding-dino", "GroundingDinoModel"),
@@ -558,6 +559,7 @@
("gptj", "GPTJForCausalLM"),
("granite", "GraniteForCausalLM"),
("granitemoe", "GraniteMoeForCausalLM"),
("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
("granitemoeshared", "GraniteMoeSharedForCausalLM"),
("helium", "HeliumForCausalLM"),
("jamba", "JambaForCausalLM"),
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -854,6 +854,7 @@ def torch_forward(
# Init cache
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
cache_params.has_previous_state = True

scan_output = self.norm(y, gate)

1 change: 1 addition & 0 deletions src/transformers/models/bamba/modular_bamba.py
@@ -651,6 +651,7 @@ def torch_forward(
# Init cache
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
cache_params.has_previous_state = True

scan_output = self.norm(y, gate)

Original file line number Diff line number Diff line change
@@ -166,6 +166,8 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# this model has rope embedding type, hardcoded for BC
self.position_embedding_type = "rope"

self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
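
The "hardcoded for BC" comment above reflects a backward-compatibility pattern: the existing configuration pins `position_embedding_type` to `"rope"` so released checkpoints keep their behavior, while the new hybrid model can choose a different setting (the docs above state its attention layers use no positional encoding by default). A minimal, hypothetical sketch of that pattern follows; the class names and default values are illustrative, not the actual transformers classes.

```python
class BaseConfig:
    """Illustrative stand-in for the existing configuration patched above."""

    def __init__(self, **kwargs):
        # hardcoded for backward compatibility: existing checkpoints were
        # trained with rotary position embeddings
        self.position_embedding_type = "rope"


class HybridConfig(BaseConfig):
    """Illustrative stand-in for the new hybrid configuration."""

    def __init__(self, position_embedding_type=None, **kwargs):
        super().__init__(**kwargs)
        # the hybrid model overrides the hardcoded value; None here stands in
        # for "no positional encoding", the documented default
        self.position_embedding_type = position_embedding_type
```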