Skip to content
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/granite.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,8 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))

[[autodoc]] GraniteForCausalLM
- forward

## GraniteForSequenceClassification

[[autodoc]] GraniteForSequenceClassification
- forward
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/granitemoe.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ This model was contributed by [mayank-mishra](https://huggingface.co/mayank-mish

[[autodoc]] GraniteMoeForCausalLM
- forward

## GraniteMoeForSequenceClassification

[[autodoc]] GraniteMoeForSequenceClassification
- forward
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/granitemoehybrid.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,8 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co

[[autodoc]] GraniteMoeHybridForCausalLM
- forward

## GraniteMoeHybridForSequenceClassification

[[autodoc]] GraniteMoeHybridForSequenceClassification
- forward
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/granitemoeshared.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ This HF implementation is contributed by [Mayank Mishra](https://huggingface.co/

[[autodoc]] GraniteMoeSharedForCausalLM
- forward

## GraniteMoeSharedForSequenceClassification

[[autodoc]] GraniteMoeSharedForSequenceClassification
- forward
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gpt_oss", "GptOssForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("granite", "GraniteForSequenceClassification"),
("granitemoe", "GraniteMoeForSequenceClassification"),
("granitemoehybrid", "GraniteMoeHybridForSequenceClassification"),
("granitemoeshared", "GraniteMoeSharedForSequenceClassification"),
("helium", "HeliumForSequenceClassification"),
("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
Expand Down Expand Up @@ -588,4 +588,8 @@ def forward(
)


__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]
class GraniteForSequenceClassification(GenericForSequenceClassification, GranitePreTrainedModel):
    """
    Sequence-classification variant of Granite.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GranitePreTrainedModel`.
    """


__all__ = ["GraniteForCausalLM", "GraniteForSequenceClassification", "GraniteModel", "GranitePreTrainedModel"]
7 changes: 6 additions & 1 deletion src/transformers/models/granite/modular_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
Expand Down Expand Up @@ -276,4 +277,8 @@ def forward(
)


__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]
class GraniteForSequenceClassification(GenericForSequenceClassification, GranitePreTrainedModel):
    """
    Sequence-classification variant of Granite.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GranitePreTrainedModel`.
    """


__all__ = ["GraniteForCausalLM", "GraniteForSequenceClassification", "GraniteModel", "GranitePreTrainedModel"]
13 changes: 11 additions & 2 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
Expand Down Expand Up @@ -741,4 +741,13 @@ def forward(
)


__all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"]
class GraniteMoeForSequenceClassification(GenericForSequenceClassification, GraniteMoePreTrainedModel):
    """
    Sequence-classification variant of GraniteMoe.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoePreTrainedModel`.
    """


__all__ = [
"GraniteMoeForCausalLM",
"GraniteMoeForSequenceClassification",
"GraniteMoeModel",
"GraniteMoePreTrainedModel",
]
12 changes: 11 additions & 1 deletion src/transformers/models/granitemoe/modular_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
Expand Down Expand Up @@ -323,4 +324,13 @@ def forward(
)


__all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"]
class GraniteMoeForSequenceClassification(GenericForSequenceClassification, GraniteMoePreTrainedModel):
    """
    Sequence-classification variant of GraniteMoe.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoePreTrainedModel`.
    """


__all__ = [
"GraniteMoeForCausalLM",
"GraniteMoeForSequenceClassification",
"GraniteMoeModel",
"GraniteMoePreTrainedModel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...integrations.hub_kernels import lazy_load_kernel
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
Expand Down Expand Up @@ -1588,4 +1588,13 @@ def prepare_inputs_for_generation(
return model_inputs


__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
class GraniteMoeHybridForSequenceClassification(GenericForSequenceClassification, GraniteMoeHybridPreTrainedModel):
    """
    Sequence-classification variant of GraniteMoeHybrid.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoeHybridPreTrainedModel`.
    """


__all__ = [
"GraniteMoeHybridForCausalLM",
"GraniteMoeHybridForSequenceClassification",
"GraniteMoeHybridModel",
"GraniteMoeHybridPreTrainedModel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForSequenceClassification
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
Expand Down Expand Up @@ -359,4 +360,13 @@ def prepare_inputs_for_generation(
return model_inputs


__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
class GraniteMoeHybridForSequenceClassification(GenericForSequenceClassification, GraniteMoeHybridPreTrainedModel):
    """
    Sequence-classification variant of GraniteMoeHybrid.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoeHybridPreTrainedModel`.
    """


__all__ = [
"GraniteMoeHybridForCausalLM",
"GraniteMoeHybridForSequenceClassification",
"GraniteMoeHybridModel",
"GraniteMoeHybridPreTrainedModel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
Expand Down Expand Up @@ -810,4 +810,13 @@ def forward(
)


__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
class GraniteMoeSharedForSequenceClassification(GenericForSequenceClassification, GraniteMoeSharedPreTrainedModel):
    """
    Sequence-classification variant of GraniteMoeShared.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoeSharedPreTrainedModel`.
    """


__all__ = [
"GraniteMoeSharedForCausalLM",
"GraniteMoeSharedForSequenceClassification",
"GraniteMoeSharedModel",
"GraniteMoeSharedPreTrainedModel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_layers import GenericForSequenceClassification
from ...processing_utils import Unpack
from ...utils import logging
from ..granitemoe.modeling_granitemoe import (
Expand Down Expand Up @@ -153,4 +154,13 @@ def __init__(self, config: GraniteMoeSharedConfig):
self.post_init()


__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
class GraniteMoeSharedForSequenceClassification(GenericForSequenceClassification, GraniteMoeSharedPreTrainedModel):
    """
    Sequence-classification variant of GraniteMoeShared.

    The class defines no members of its own: everything is inherited from `GenericForSequenceClassification` and
    `GraniteMoeSharedPreTrainedModel`.
    """


__all__ = [
"GraniteMoeSharedForCausalLM",
"GraniteMoeSharedForSequenceClassification",
"GraniteMoeSharedModel",
"GraniteMoeSharedPreTrainedModel",
]
17 changes: 17 additions & 0 deletions tests/models/granite/test_modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from transformers import (
GraniteForCausalLM,
GraniteForSequenceClassification,
GraniteModel,
)

Expand Down Expand Up @@ -140,6 +141,16 @@ def create_and_check_model(
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

def create_and_check_for_sequence_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """
    Run `GraniteForSequenceClassification` once and check the shape of its logits.

    `config.num_labels` is overwritten with the tester's `num_labels` before the model is built. The model is
    moved to `torch_device`, put in eval mode, and called with `input_ids`, `input_mask` (as the attention mask)
    and `sequence_labels`. `token_type_ids`, `token_labels` and `choice_labels` are not used here; they are
    accepted so the caller can unpack the result of `prepare_config_and_inputs()` straight into this method.
    """
    config.num_labels = self.num_labels

    classifier = GraniteForSequenceClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(input_ids, attention_mask=input_mask, labels=sequence_labels)

    # One row of logits per example, one column per label.
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
Expand All @@ -161,6 +172,7 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
(
GraniteModel,
GraniteForCausalLM,
GraniteForSequenceClassification,
)
if is_torch_available()
else ()
Expand All @@ -169,6 +181,7 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
{
"feature-extraction": GraniteModel,
"text-generation": GraniteForCausalLM,
"text-classification": GraniteForSequenceClassification,
}
if is_torch_available()
else {}
Expand All @@ -189,6 +202,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_sequence_classification(self):
    """Build a config and inputs with the model tester and run its sequence-classification check on them."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_for_sequence_classification
    check(*prepared)


@require_torch_accelerator
class GraniteIntegrationTest(unittest.TestCase):
Expand Down
17 changes: 17 additions & 0 deletions tests/models/granitemoe/test_modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from transformers import (
GraniteMoeForCausalLM,
GraniteMoeForSequenceClassification,
GraniteMoeModel,
)

Expand Down Expand Up @@ -139,6 +140,16 @@ def create_and_check_model(
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

def create_and_check_for_sequence_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """
    Run `GraniteMoeForSequenceClassification` once and check the shape of its logits.

    `config.num_labels` is overwritten with the tester's `num_labels` before the model is built. The model is
    moved to `torch_device`, put in eval mode, and called with `input_ids`, `input_mask` (as the attention mask)
    and `sequence_labels`. `token_type_ids`, `token_labels` and `choice_labels` are not used here; they are
    accepted so the caller can unpack the result of `prepare_config_and_inputs()` straight into this method.
    """
    config.num_labels = self.num_labels

    classifier = GraniteMoeForSequenceClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(input_ids, attention_mask=input_mask, labels=sequence_labels)

    # One row of logits per example, one column per label.
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
Expand All @@ -160,6 +171,7 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
(
GraniteMoeModel,
GraniteMoeForCausalLM,
GraniteMoeForSequenceClassification,
)
if is_torch_available()
else ()
Expand All @@ -168,6 +180,7 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
{
"feature-extraction": GraniteMoeModel,
"text-generation": GraniteMoeForCausalLM,
"text-classification": GraniteMoeForSequenceClassification,
}
if is_torch_available()
else {}
Expand All @@ -188,6 +201,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_sequence_classification(self):
    """Build a config and inputs with the model tester and run its sequence-classification check on them."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_for_sequence_classification
    check(*prepared)


@require_torch_accelerator
class GraniteMoeIntegrationTest(unittest.TestCase):
Expand Down
Loading