5 changes: 5 additions & 0 deletions docs/source/en/model_doc/gemma3.md
@@ -267,3 +267,8 @@ visualizer("<img>What is shown in this image?")

[[autodoc]] Gemma3ForConditionalGeneration
- forward

## Gemma3ForSequenceClassification

[[autodoc]] Gemma3ForSequenceClassification
- forward
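
A minimal usage sketch of the new head (not part of this diff). The checkpoint id below is a placeholder, and the freshly initialized `score` head would need fine-tuning before its predictions mean anything; `num_labels` is a top-level config attribute, so passing it to `from_pretrained` follows the usual *ForSequenceClassification pattern.

```python
import torch
from transformers import AutoTokenizer, Gemma3ForSequenceClassification

model_id = "google/gemma-3-4b-it"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The classification head is newly initialized; fine-tune before relying on outputs.
model = Gemma3ForSequenceClassification.from_pretrained(model_id, num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
predicted_class = logits.argmax(dim=-1).item()
```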
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -1131,6 +1131,7 @@
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("gemma3", "Gemma3ForSequenceClassification"),
("glm", "GlmForSequenceClassification"),
("glm4", "Glm4ForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
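
A sketch of what the new `"gemma3"` mapping entry enables: `AutoModelForSequenceClassification` now resolves a Gemma3 config to the new class. The checkpoint id is again a placeholder, and `from_config` builds a randomly initialized model purely to demonstrate the dispatch.

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("google/gemma-3-4b-it")  # placeholder id
config.num_labels = 3
# The "gemma3" entry added above routes this call to Gemma3ForSequenceClassification.
model = AutoModelForSequenceClassification.from_config(config)
assert type(model).__name__ == "Gemma3ForSequenceClassification"
```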
91 changes: 90 additions & 1 deletion src/transformers/models/gemma3/modeling_gemma3.py
@@ -34,7 +34,7 @@
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -1212,10 +1212,99 @@ def create_masks_for_generate(
return create_masks_for_generate(**mask_kwargs)


class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Gemma3Model(config)
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""

transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids,
use_cache=use_cache,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.text_config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.text_config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


__all__ = [
"Gemma3PreTrainedModel",
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
"Gemma3ForSequenceClassification",
]
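
The pooling in `forward()` selects, per sequence, the logits at the last non-padding position. A self-contained illustration of that index computation, with made-up values (not from the PR): multiplying position indices by the non-pad mask and taking `argmax` picks the rightmost non-pad token, which handles both left- and right-padding.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # right-padded: last real token at index 2
    [0, 0, 5, 6, 7],   # left-padded:  last real token at index 4
])
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
# Positions of pad tokens are zeroed out, so argmax returns the rightmost non-pad index.
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```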
91 changes: 90 additions & 1 deletion src/transformers/models/gemma3/modular_gemma3.py
@@ -27,7 +27,7 @@
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
@@ -1069,6 +1069,94 @@ def create_masks_for_generate(
return create_masks_for_generate(**mask_kwargs)


class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Gemma3Model(config)
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""

transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids,
use_cache=use_cache,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.text_config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.text_config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


__all__ = [
"Gemma3Config",
"Gemma3TextConfig",
@@ -1077,4 +1165,5 @@ def create_masks_for_generate(
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
"Gemma3ForSequenceClassification",
]
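
A rough sketch of the loss behaviour described in the `labels` docstring above. The actual dispatch happens inside `self.loss_function`, which is provided by the library's shared loss utilities; the helper name and the omission of the multi-label branch are simplifications for illustration.

```python
import torch
from torch import nn

def sequence_classification_loss(pooled_logits, labels, num_labels):
    # Approximation of the docstring's rule: one label -> regression, else classification.
    if num_labels == 1:
        return nn.MSELoss()(pooled_logits.squeeze(), labels.squeeze().float())
    return nn.CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))

# Example: batch of 2 sequences, 3 classes
pooled_logits = torch.randn(2, 3)
labels = torch.tensor([0, 2])
print(sequence_classification_loss(pooled_logits, labels, num_labels=3))
```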
10 changes: 10 additions & 0 deletions tests/models/gemma3/test_modeling_gemma3.py
@@ -53,6 +53,7 @@
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
Gemma3Model,
Gemma3Processor,
Gemma3TextModel,
@@ -246,6 +247,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
(
Gemma3Model,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
)
if is_torch_available()
else ()
@@ -348,6 +350,14 @@ def test_eager_matches_fa2_generate(self):
def test_initialization(self):
pass

@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
def test_load_with_mismatched_shapes(self):
pass

@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
def test_mismatched_shapes_have_properly_initialized_weights(self):
pass

Comment on lines +353 to +360
@zucchini-nlp (Member, Author) commented on Jul 17, 2025:

Here I mean loading like Gemma3Config.from_dict(config_dict, vocab_size=100), where vocab_size is actually part of config.text_config. It is a known issue and I have it in my plans to support it.

def test_automodelforcausallm(self):
"""
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
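
An illustration of the limitation behind the two skipped tests, paraphrasing the review comment above (the behaviour is reported, not asserted by this PR): a top-level keyword override that actually belongs to the nested text config does not currently reach it.

```python
from transformers import Gemma3Config

config_dict = Gemma3Config().to_dict()
# vocab_size lives under config.text_config, so a top-level override is
# reported not to propagate into the nested text config yet.
config = Gemma3Config.from_dict(config_dict, vocab_size=100)
print(config.text_config.vocab_size)  # would still show the original value
```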