
Commit 3bb1b48

Standardize audio embedding function name for audio multimodal models (#40919)
* Standardize audio embedding function name for audio multimodal models
* PR review
1 parent 58e13b9 commit 3bb1b48

File tree: 2 files changed (+18, -4 lines)


src/transformers/models/voxtral/modeling_voxtral.py

Lines changed: 9 additions & 2 deletions
@@ -20,6 +20,7 @@
 # limitations under the License.
 
 import math
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -431,7 +432,7 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def get_audio_embeds(self, input_features: torch.FloatTensor):
+    def get_audio_features(self, input_features: torch.FloatTensor):
         """
         This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
         Args:
@@ -452,6 +453,12 @@ def get_audio_embeds(self, input_features: torch.FloatTensor):
         audio_embeds = self.multi_modal_projector(audio_hidden_states)
         return audio_embeds
 
+    def get_audio_embeds(self, input_features: torch.FloatTensor):
+        warnings.warn(
+            "The method `get_audio_embeds` is deprecated. Please use `get_audio_features` instead.", FutureWarning
+        )
+        return self.get_audio_features(input_features)
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -505,7 +512,7 @@ def forward(
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
         if input_features is not None and input_ids is not None:
-            audio_embeds = self.get_audio_embeds(input_features)
+            audio_embeds = self.get_audio_features(input_features)
 
             # replace text-audio token placeholders with audio embeddings
             audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
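
For context, a minimal sketch of calling the renamed method after this commit. The checkpoint name and the spectrogram shape are illustrative assumptions, not taken from the diff; in practice the input features come from the Voxtral processor.

# Sketch only: checkpoint name and input shape are assumptions.
import torch
from transformers import VoxtralForConditionalGeneration

model = VoxtralForConditionalGeneration.from_pretrained("mistralai/Voxtral-Mini-3B-2507")

# A log-mel spectrogram batch; 128 mel bins x 3000 frames assumed for illustration.
input_features = torch.randn(1, 128, 3000)

audio_embeds = model.get_audio_features(input_features)   # new name introduced by this commit
legacy_embeds = model.get_audio_embeds(input_features)    # still works, but emits a FutureWarning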

src/transformers/models/voxtral/modular_voxtral.py

Lines changed: 9 additions & 2 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from typing import Optional, Union
 
 import torch
@@ -166,7 +167,7 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def get_audio_embeds(self, input_features: torch.FloatTensor):
+    def get_audio_features(self, input_features: torch.FloatTensor):
         """
         This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
         Args:
@@ -187,6 +188,12 @@ def get_audio_embeds(self, input_features: torch.FloatTensor):
         audio_embeds = self.multi_modal_projector(audio_hidden_states)
         return audio_embeds
 
+    def get_audio_embeds(self, input_features: torch.FloatTensor):
+        warnings.warn(
+            "The method `get_audio_embeds` is deprecated. Please use `get_audio_features` instead.", FutureWarning
+        )
+        return self.get_audio_features(input_features)
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -240,7 +247,7 @@ def forward(
            inputs_embeds = self.get_input_embeddings()(input_ids)
 
         if input_features is not None and input_ids is not None:
-            audio_embeds = self.get_audio_embeds(input_features)
+            audio_embeds = self.get_audio_features(input_features)
 
             # replace text-audio token placeholders with audio embeddings
             audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
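
The modular file carries the same shim, so the old entry point keeps working while warning callers. A self-contained sketch of the rename-plus-deprecation pattern used here; the class below is illustrative, not the actual Voxtral code.

import warnings

class Example:
    def get_audio_features(self, input_features):
        # New canonical name; the real model runs the audio encoder and multi-modal projector here.
        return input_features

    def get_audio_embeds(self, input_features):
        # Deprecated alias kept for backward compatibility, mirroring the diff above.
        warnings.warn(
            "The method `get_audio_embeds` is deprecated. Please use `get_audio_features` instead.",
            FutureWarning,
        )
        return self.get_audio_features(input_features)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Example().get_audio_embeds([0.0])

assert any(issubclass(w.category, FutureWarning) for w in caught)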
