
Remove many output_attentions and other traced outputs on 100+ models #43590

Open
molbap wants to merge 176 commits into main from update_all_decorators

Conversation

@molbap
Contributor

@molbap molbap commented Jan 29, 2026

What does this PR do?

New model additions often follow old standards that don't use check_model_inputs or can_return_tuple; it's a frequent first review comment and something that can easily slip through. This PR does a wide scan to remove all remaining occurrences systematically.

Background

Every model used to manually resolve output_attentions, output_hidden_states, and return_dict in each forward, then collect intermediate outputs in a loop, then convert to tuple at the end. That's ~30 lines of boilerplate per model, reimplemented everywhere with subtle inconsistencies.

Two decorators now handle this:

  • @capture_outputs goes on the base model forward (the one with the layer loop). It reads output_attentions/output_hidden_states from kwargs or config, installs hooks on modules listed in _can_record_outputs, collects intermediate outputs automatically, injects them into the ModelOutput, and handles return_dict. The model just needs to declare which module classes produce which outputs (e.g. _can_record_outputs = {"hidden_states": DecoderLayer, "attentions": Attention}).

  • @can_return_tuple goes on wrapper forwards (ForCausalLM, ForSequenceClassification, VLM wrappers) that only need return_dict conversion. Wrapper models should not use @capture_outputs to avoid nested hook chains.
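As a rough illustration of the capture mechanism, here is a self-contained toy sketch (not the transformers implementation; the Layer class, its hook mechanics, and the plain-dict output are all invented for illustration):

```python
import functools

class Layer:
    """Toy decoder layer with a minimal post-forward hook mechanism."""
    def __init__(self, scale):
        self.scale = scale
        self._hooks = []

    def register_hook(self, fn):
        self._hooks.append(fn)
        return lambda: self._hooks.remove(fn)  # handle to remove the hook

    def __call__(self, x):
        out = [v * self.scale for v in x]
        for fn in self._hooks:
            fn(out)
        return out

def capture_outputs(forward):
    """Install hooks on modules declared in _can_record_outputs, collect
    their outputs, and inject them into the returned output dict."""
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        want = kwargs.pop("output_hidden_states", False)
        recorded, removers = [], []
        if want:
            layer_cls = self._can_record_outputs["hidden_states"]
            for module in self.layers:
                if isinstance(module, layer_cls):
                    removers.append(module.register_hook(recorded.append))
        output = forward(self, *args, **kwargs)
        for remove in removers:
            remove()
        if want:
            output["hidden_states"] = tuple(recorded)
        return output
    return wrapper

class ToyModel:
    # Declare which module classes produce which outputs
    _can_record_outputs = {"hidden_states": Layer}

    def __init__(self):
        self.layers = [Layer(2), Layer(3)]

    @capture_outputs
    def forward(self, x, **kwargs):
        # The base forward is just the bare layer loop
        for layer in self.layers:
            x = layer(x)
        return {"last_hidden_state": x}

model = ToyModel()
out = model.forward([1.0], output_hidden_states=True)
```

The point of the design is visible even in the toy: the forward contains only the layer loop, while all collection logic lives in the decorator.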

What changes per model

  • output_attentions, output_hidden_states, return_dict dropped from forward signatures, replaced by **kwargs: Unpack[TransformersKwargs]
  • Explicit parameter resolution lines removed
  • Manual all_hidden_states += (hidden_states,) collection loops removed
  • Decoder layers return a single tensor instead of a tuple
  • Attention modules always return (attn_output, attn_weights) — the if not output_attentions: attn_weights = None guard is removed since hooks capture directly from the module output
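The return_dict conversion done by the wrapper decorator can be sketched the same way (a toy, not the library code; the ToyCausalLMOutput dataclass and the None-dropping behavior are assumptions for illustration):

```python
import functools
from dataclasses import dataclass, fields
from typing import Optional

def can_return_tuple(forward):
    """Convert a dataclass-style output to a tuple when return_dict=False."""
    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        return_dict = kwargs.pop("return_dict", True)
        output = forward(self, *args, **kwargs)
        if not return_dict:
            # Drop None fields when converting, a common to_tuple() convention
            return tuple(
                getattr(output, f.name)
                for f in fields(output)
                if getattr(output, f.name) is not None
            )
        return output
    return wrapper

@dataclass
class ToyCausalLMOutput:
    logits: list
    loss: Optional[float] = None

class ToyForCausalLM:
    @can_return_tuple
    def forward(self, input_values, **kwargs):
        return ToyCausalLMOutput(logits=[v + 1 for v in input_values])

model = ToyForCausalLM()
as_dataclass = model.forward([0.0])
as_tuple = model.forward([0.0], return_dict=False)
```

Because the wrapper only does this conversion, it composes cleanly with a base model that already captures outputs via hooks, without installing a second hook chain.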

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Mar 5, 2026

run-slow: dinov3_convnext,dinov3_vit,zamba,layoutlmv2

@github-actions
Contributor

github-actions bot commented Mar 5, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/dinov3_convnext", "models/dinov3_vit", "models/layoutlmv2", "models/zamba"]
quantizations: []

Contributor

@vasqu vasqu left a comment

Small self-review to point out some special things


# 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
- irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values", "cache_params"]
Contributor

This is needed as encoders oftentimes share the same attention module with the decoder, meaning that if we pass the cache there, everything gets messy
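The filtering this diff extends can be sketched like so (the model_kwargs values here are placeholders, not real arguments):

```python
# Names follow the snippet above; values are placeholders for illustration.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "past_key_values", "cache_params"]

model_kwargs = {
    "attention_mask": "enc-mask",
    "decoder_input_ids": "dec-ids",
    "past_key_values": "cache",
    "use_cache": True,
}

# Keep only kwargs that do not start with an irrelevant prefix, so the cache
# (and other decoder-only arguments) never reaches the encoder forward.
encoder_kwargs = {
    k: v for k, v in model_kwargs.items()
    if not any(k.startswith(p) for p in irrelevant_prefix)
}
```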

Member

Humm, I don't get why we need this suddenly? Even encoder may want the cache with EncoderDecoder cache no?

Contributor

But the encoder module itself never wants a cache; it's only forwarded once and then "saved" for each subsequent step. Only the decoder needs the cache to properly overwrite states.

If the encoder also gets the cache, then it can update it as well, which breaks generate for certain methods (not sure which anymore, but CI was broken for a few tests on bart then).

Member

Ha yes, cause you made sure to propagate kwargs now as well! Makes sense then!

super().__init__()
self.config = config
self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
Contributor

This was a wrong pattern, this is likely not needed in many cases - I tried to remove them whenever I found them

Member

@zucchini-nlp zucchini-nlp Mar 6, 2026

I don't know if this is used anywhere in training code, but I see that we set it to True after enabling GC. Could we split the deletion of GC (if not needed) into its own PR?

# Apply it on the top-level module in case the top-level module supports it,
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
    self._gradient_checkpointing_func = gradient_checkpointing_func
    self.gradient_checkpointing = enable
    is_gradient_checkpointing_set = True

Contributor

It's not needed because the GC layer handles this internally now

class GradientCheckpointingLayer(nn.Module):
    """Base class for layers with gradient checkpointing.

    This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
    (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
    enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.

    Important:
        When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
        must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.

        Example:

        ```python
        >>> # Correct - hidden_states passed as positional arg
        >>> out = self.layer(hidden_states, attention_mask=attention_mask)

        >>> # Incorrect - hidden_states passed as keyword arg
        >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
        ```
    """

    gradient_checkpointing = False

Agree tho that it probably makes more sense within its own PR, lemme revert these here and open a big PR for this instead

Comment on lines +425 to +437
class Dinov2Encoder(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)])
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)

        return BaseModelOutput(last_hidden_state=hidden_states)
Contributor

This is a new change to make this module the "collector": it allows us to have one entry point instead of duplicating efforts in the parent classes. We often had weird backbone structures, so some got a bit more of a refactor; see the conversion mapping for those that saw the most changes.

if output_hidden_states is None:
    output_hidden_states = self.config.output_hidden_states
kwargs["output_hidden_states"] = True  # required to extract layers for the stages
Contributor

Also, as a heads-up: I'm adding comments in the backbone utils, but this is properly wrapped thanks to the mixin

Comment on lines +169 to +176
@functools.wraps(forward_function)
def wrapper(self, *args, **kwargs):
    output_hidden_states = kwargs.get("output_hidden_states", getattr(self.config, "output_hidden_states", False))
    output = forward_function(self, *args, **kwargs)
    if not output_hidden_states:
        filtered_output_data = {k: v for k, v in output.items() if k not in ("hidden_states",)}
        output = type(output)(**filtered_output_data)
    return output
Contributor

This is the new wrapper to make backbones behave as expected: they always output hidden states, so we control this here.

output = type(output)(**filtered_output_data) is a bit weird, but it allows us to construct our modeling outputs properly, since there is no delete function and I don't think we want one
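The rebuild-instead-of-delete trick can be shown on a toy dataclass (ToyBackboneOutput is invented for the example; the real modeling outputs behave similarly):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class ToyBackboneOutput:
    feature_maps: tuple
    hidden_states: Optional[tuple] = None

out = ToyBackboneOutput(feature_maps=(1, 2), hidden_states=(3, 4))

# There is no way to delete a field from a dataclass instance, so rebuild
# the output from its remaining fields instead.
filtered = {k: v for k, v in asdict(out).items() if k not in ("hidden_states",)}
out = type(out)(**filtered)
```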

Comment on lines +207 to +211
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)

    if "forward" in cls.__dict__:
        cls.forward = can_return_tuple(filter_output_hidden_states(cls.forward))
Contributor

return tuple to guarantee the proper dict type within the filter decorator
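The auto-wrapping pattern itself can be sketched with a toy decorator (log_calls and Backbone are invented for illustration; the transformers version wraps forward with the real decorators):

```python
def log_calls(fn):
    """Toy stand-in for the real wrapping decorators: counts invocations."""
    def wrapper(self, *args, **kwargs):
        self.calls = getattr(self, "calls", 0) + 1
        return fn(self, *args, **kwargs)
    return wrapper

class AutoWrapped:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only wrap a forward defined on this subclass itself, not one
        # inherited from a parent, so wrapping never happens twice.
        if "forward" in cls.__dict__:
            cls.forward = log_calls(cls.forward)

class Backbone(AutoWrapped):
    def forward(self, x):
        return x * 2

backbone = Backbone()
result = backbone.forward(3)
```

The `"forward" in cls.__dict__` check is what makes the later suggestion (explicit per-model decorators) a drop-in alternative: either way, each forward is wrapped exactly once.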

Comment on lines +71 to +76
"dinov3_convnext": [WeightRenaming(r"(?<!model\.)stages", r"model.stages")],
"dinov3_vit": [WeightRenaming(r"layer_scale", r"scale"), WeightRenaming(r"(?<!model\.)layer", r"model.layer")],
"zamba": [
    WeightRenaming(r"layers.(\d+).mamba(?!_decoder)", r"layers.\1.mamba_decoder.mamba"),
    WeightRenaming(r"layers.(\d+).input_layernorm", r"layers.\1.mamba_decoder.input_layernorm"),
],
Contributor

These are the "special" models
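The rename patterns above can be exercised with plain `re` (the checkpoint key names below are illustrative, not taken from real checkpoints):

```python
import re

# dinov3_convnext: prefix `stages` with `model.` only when not already prefixed,
# via a negative lookbehind.
convnext_pat, convnext_repl = r"(?<!model\.)stages", r"model.stages"

# zamba: nest `mamba` under `mamba_decoder`, with a negative lookahead so
# already-converted keys are left alone.
zamba_pat, zamba_repl = r"layers.(\d+).mamba(?!_decoder)", r"layers.\1.mamba_decoder.mamba"

renamed = re.sub(convnext_pat, convnext_repl, "stages.0.weight")
already = re.sub(convnext_pat, convnext_repl, "model.stages.0.weight")
zamba = re.sub(zamba_pat, zamba_repl, "model.layers.3.mamba.in_proj.weight")
```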

Member

let's test all with slow CI before merging

Contributor

I was one step ahead of you #43590 (comment) 😂

Comment on lines +129 to +132
"Kosmos2TextTransformer",
"Kosmos2VisionTransformer",
"Kosmos2_5TextTransformer",
"XCLIPVisionTransformer",
Contributor

Same-ish issue as CLIP and similar models, which do (at least) one wrapper too many

@github-actions
Contributor

github-actions bot commented Mar 5, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 1287530f workflow commit (merge commit)
PR 433f8170 branch commit (from PR)
main e498b5bd base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

Member

@zucchini-nlp zucchini-nlp left a comment

Huge PR! Can we trigger slow tests, since some models were refactored along the way?

Also it would be great to check whether the GC attribute is needed for some BC behavior, because it's used in modeling files and the "source-of-truth" Llama also has the attribute

Comment on lines -938 to 897
if token_type_ids is None:
    if hasattr(self.embeddings, "token_type_ids"):
        buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
        buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
        token_type_ids = buffered_token_type_ids_expanded
    else:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
Member

why is this deleted?

Contributor

Bad pattern from previous models, the embedding module already handles this

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
# issue #5664
if token_type_ids is None:
    if hasattr(self, "token_type_ids"):
        buffered_token_type_ids = self.token_type_ids[:, :seq_length]
        buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
        token_type_ids = buffered_token_type_ids_expanded
    else:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

Comment on lines -988 to -995
if token_type_ids is None:
    if hasattr(self.embeddings, "token_type_ids"):
        buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
        buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
        token_type_ids = buffered_token_type_ids_expanded
    else:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

Member

same here 👀

oh actually, let's also run slow CI with important model list before merging

Contributor

So the qwens etc? Or any preference on models (outside the ones I already checked on that one run-slow)

Member

I'd prefer models that were refactored here, plus a couple of multimodal models and backbones; those always return hidden states and re-use them further

Comment on lines -366 to -382
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        encoder_hidden_states (`torch.FloatTensor`):
            cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
        encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        past_key_values (`Cache`): cached past key and value projection states
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
            cache in the correct position and to infer the complete sequence length.
Member

@auto_docstring missing, I believe

Contributor

The parent modules already have proper auto docstrings; it doesn't really make sense to have docs on the layers (which mostly have the same signature), as they are not really user-facing

context_layer = torch.transpose(context_layer, 1, 2)

# this is just for visualizing; forward pass doesn't depend on following code
if output_attentions:
Member

too many lines changed, let's slow test this model

else:
self.mlp = DINOv3ViTMLP(config)
self.layer_scale2 = DINOv3ViTLayerScale(config)
self.scale2 = DINOv3ViTLayerScale(config)
Member

are we completely refactoring the model, I see layer names changed! Let's slow test as well

init.zeros_(module.position_embeddings)


class DPTViTEncoder(nn.Module):
Member

also to slow test

Comment on lines +118 to +127
elif (
    hasattr(module, "_get_pos_embed_values")
    and hasattr(module, "feat_shape")
    and module.feat_shape is not None
):
    module.pos_embed = module._get_pos_embed_values(
        feat_shape=module.feat_shape,
        device=module.pos_embed.device if module.pos_embed is not None else None,
        dtype=module.pos_embed.dtype if module.pos_embed is not None else torch.float32,
    )
Member

is it not supposed to be covered on timm-side when calling init_non_persistent_buffers? Prob we should ask for a fix from timm team

Contributor

Not sure tbh, this was already there when I took over 👀


@vasqu
Contributor

vasqu commented Mar 6, 2026

run-slow: aimv2,align,altclip,apertus,aria,audio_spectrogram_transformer,audioflamingo3,autoformer,aya_vision,bamba,bart,bert,bert_generation,big_bird,bigbird_pegasus,biogpt,blenderbot,blenderbot_small,blip,bridgetower,bros,camembert,chameleon,chinese_clip,clap,clip,clipseg,clvp,cohere,cohere2,cohere2_vision,colpali,convbert,convnext,convnextv2,data2vec,decision_transformer,deit,dia,dinov2,dinov2_with_registers,dinov3_convnext,dinov3_vit,dpt,electra,eomt_dinov3,ernie,ernie4_5_vl_moe,falcon_h1,fast_vlm,florence2,fuyu,gemma3n,git,glm46v,glm4v,glm4v_moe,glm_image,glm_ocr,got_ocr2,gpt_bigcode,gpt_neox,granite,groupvit,idefics,idefics2,idefics3,ijepa,informer,instructblipvideo,internvl,janus,kosmos2,kosmos2_5,layoutlm,layoutlmv2,layoutlmv3,lightglue,lighton_ocr,llava,llava_next,llava_next_video,llava_onevision,lw_detr,m2m_100,marian,markuplm,mbart,metaclip_2,mistral3,mobilebert,musicgen,musicgen_melody,nemotron,nemotron_h,opt,ovis2,owlv2,owlvit,paddleocr_vl,paligemma,pegasus,pegasus_x,perception_lm,persimmon,phi,phi4_multimodal,pixio,pixtral,plbart,pp_doclayout_v2,prompt_depth_anything,qwen2_5_omni,qwen2_5_vl,qwen2_audio,qwen2_vl,qwen3_5,qwen3_5_moe,qwen3_omni_moe,qwen3_vl,qwen3_vl_moe,roberta,roberta_prelayernorm,roc_bert,sam,siglip,siglip2,smolvlm,speech_to_text,splinter,stablelm,time_series_transformer,timesfm,timesfm2_5,timm_wrapper,video_llama_3,video_llava,videomae,vipllava,vit,vit_mae,vit_msn,vitpose_backbone,vivit,vjepa2,voxtral,voxtral_realtime,whisper,x_clip,xglm,xlm_roberta,xlm_roberta_xl,xlstm,xmod,yolos,zamba,zamba2

@github-actions
Contributor

github-actions bot commented Mar 6, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/aimv2", "models/align", "models/altclip", "models/apertus", "models/aria", "models/audio_spectrogram_transformer", "models/audioflamingo3", "models/autoformer", "models/aya_vision", "models/bamba", "models/bart", "models/bert", "models/bert_generation", "models/big_bird", "models/bigbird_pegasus", "models/biogpt", "models/blenderbot", "models/blenderbot_small", "models/blip", "models/bridgetower", "models/bros", "models/camembert", "models/chameleon", "models/chinese_clip", "models/clap", "models/clip", "models/clipseg", "models/clvp", "models/cohere", "models/cohere2", "models/cohere2_vision", "models/colpali", "models/convbert", "models/convnext", "models/convnextv2", "models/data2vec", "models/decision_transformer", "models/deit", "models/dia", "models/dinov2", "models/dinov2_with_registers", "models/dinov3_convnext", "models/dinov3_vit", "models/dpt", "models/electra", "models/eomt_dinov3", "models/ernie", "models/ernie4_5_vl_moe", "models/falcon_h1", "models/fast_vlm", "models/florence2", "models/fuyu", "models/gemma3n", "models/git", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/got_ocr2", "models/gpt_bigcode", "models/gpt_neox", "models/granite", "models/groupvit", "models/idefics", "models/idefics2", "models/idefics3", "models/ijepa", "models/informer", "models/instructblipvideo", "models/internvl", "models/janus", "models/kosmos2", "models/kosmos2_5", "models/layoutlm", "models/layoutlmv2", "models/layoutlmv3", "models/lightglue", "models/lighton_ocr", "models/llava", "models/llava_next", "models/llava_next_video", "models/llava_onevision", "models/lw_detr", "models/m2m_100", "models/marian", "models/markuplm", "models/mbart", "models/metaclip_2", "models/mistral3", "models/mobilebert", "models/musicgen", "models/musicgen_melody", "models/nemotron", "models/nemotron_h", "models/opt", "models/ovis2", "models/owlv2", "models/owlvit", "models/paddleocr_vl", "models/paligemma", 
"models/pegasus", "models/pegasus_x", "models/perception_lm", "models/persimmon", "models/phi", "models/phi4_multimodal", "models/pixio", "models/pixtral", "models/plbart", "models/pp_doclayout_v2", "models/prompt_depth_anything", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_audio", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe", "models/roberta", "models/roberta_prelayernorm", "models/roc_bert", "models/sam", "models/siglip", "models/siglip2", "models/smolvlm", "models/speech_to_text", "models/splinter", "models/stablelm", "models/time_series_transformer", "models/timesfm", "models/timesfm2_5", "models/timm_wrapper", "models/video_llama_3", "models/video_llava", "models/videomae", "models/vipllava", "models/vit", "models/vit_mae", "models/vit_msn", "models/vitpose_backbone", "models/vivit", "models/vjepa2", "models/voxtral", "models/voxtral_realtime", "models/whisper", "models/x_clip", "models/xglm", "models/xlm_roberta", "models/xlm_roberta_xl", "models/xlstm", "models/xmod", "models/yolos", "models/zamba", "models/zamba2"]
quantizations: []

Member

@Cyrilvallez Cyrilvallez left a comment

Ok, just reviewed the most critical parts, but in general I don't think we should need any structure change/conversion! It's only supposed to make code more readable, let's not overcomplicate IMO by adding conversions that should not be needed!



Comment on lines +207 to +210
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)

    if "forward" in cls.__dict__:
Member

Why not put it on individual models as decorators instead? Would be less surprising probably no?

Contributor

Yeah, true. I was struggling a bit to make things work properly here, and this was the final solution in the end. Lemme change it to explicit decorators

Comment on lines -725 to +703
class ZambaHybridLayer(GradientCheckpointingLayer):
def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
class ZambaMixedLayer(GradientCheckpointingLayer):
def __init__(
Member

Why change the name here?? Let's keep the same for BC

Comment on lines +861 to +836
else:
layers.append(mamba)
layers.append(ZambaMixedLayer(shared_transformer, linear, mamba))
Member

Why did we change the weight structure here? Should not be necessary for capture is it?

Comment on lines -205 to -206
self.config = config
self.stages = nn.ModuleList([DINOv3ConvNextStage(config, stage_idx) for stage_idx in range(config.num_stages)])
Member

Same here, it should not be necessary to change the structure?

Comment on lines -255 to +266
self.stages = nn.ModuleList([DINOv3ConvNextStage(config, s) for s in range(config.num_stages)])
self.model = DINOv3ConvNextEncoder(config)
Member

same

self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = DINOv3ViTAttention(config)
self.layer_scale1 = DINOv3ViTLayerScale(config)
self.scale1 = DINOv3ViTLayerScale(config)
Member

Why do we change that???

Comment on lines -479 to +499
self.layer = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.model = DINOv3ViTEncoder(config)
Member

Same here, should not be needed I believe!

@github-actions
Contributor

github-actions bot commented Mar 6, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 5a789ed0 workflow commit (merge commit)
PR 433f8170 branch commit (from PR)
main 4f91111b base commit (on main)

⚠️ Model CI failed to report results

The test failure analysis could not be completed. Please check the workflow run for details.

@github-actions
Contributor

github-actions bot commented Mar 6, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, align, altclip, apertus, aria, audio_spectrogram_transformer, audioflamingo3, autoformer, aya_vision, bamba, bart, beit, bert, bert_generation, big_bird, bigbird_pegasus
