diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 1ff6eddea256..3558193b1fcc 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -854,6 +854,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -876,7 +877,8 @@ def forward( cache_position=cache_position, ) hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index ad04a4ef5b82..8000b77beb06 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -682,6 +682,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -704,7 +705,8 @@ def forward( cache_position=cache_position, ) hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2]