From 77482ee8b0a49d2f049fab0015570c703a85dd4d Mon Sep 17 00:00:00 2001 From: Aviral Date: Thu, 2 Oct 2025 01:30:52 +0530 Subject: [PATCH] added logits slicing to BioGpt for seq classifier Signed-off-by: Aviral --- src/transformers/models/biogpt/modeling_biogpt.py | 4 +++- src/transformers/models/biogpt/modular_biogpt.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 348bf2707584..ba7f191fa123 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -871,6 +871,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -894,7 +895,8 @@ def forward( cache_position=cache_position, ) hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index accc1bdc7559..b9529a2e4dcf 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -693,6 +693,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -716,7 +717,8 @@ def forward( cache_position=cache_position, ) hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2]