Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
246f2d2
commenced work on supporting audio classification task for whisper mo…
adit299 Mar 11, 2023
f80abc7
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Mar 15, 2023
9c4c724
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Mar 24, 2023
6f1524a
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Mar 31, 2023
6c4d269
initial implementation of Whisper Audio Classification in tf finished
adit299 Mar 31, 2023
7033afa
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Apr 3, 2023
5e61c4c
registering whisper audio classification in tf
adit299 Apr 4, 2023
ab91074
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Apr 4, 2023
f4bb390
Merge branch 'huggingface:main' into Add_TensorFlow_Whisper_model_for…
adit299 Apr 18, 2023
d204aea
commencing work on writing tests
adit299 Apr 27, 2023
7fe04e6
modified tests
adit299 May 1, 2023
7943e6a
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 May 5, 2023
6afd3a4
attempting to fix issues with tests
adit299 May 5, 2023
874af91
adding dummy_tf_object for whisper model
adit299 May 5, 2023
469e1ef
attempting to fix circleci tests
adit299 May 8, 2023
9e7415f
correcting mistake in previous commit
adit299 May 8, 2023
67eaf01
correcting mistakes
adit299 May 8, 2023
d545856
correcting more mistakes
adit299 May 8, 2023
c0e9320
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 May 10, 2023
4907a9a
renamed call function to forward to resolve test error
adit299 May 11, 2023
94cdcd5
attempting to resolve more test errors
adit299 May 11, 2023
24a42fe
addressing review comments
adit299 May 16, 2023
d4360d8
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 May 16, 2023
8304422
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 May 17, 2023
6f4101d
Addressing review comments and fixing code quality
adit299 May 17, 2023
896d71d
Merge branch 'main' into Add_TensorFlow_Whisper_model_for_audio_class…
adit299 Jun 12, 2023
54b6edd
solving onnx test error
adit299 Jun 18, 2023
2dfcc46
Merge branch 'Add_TensorFlow_Whisper_model_for_audio_classification' …
adit299 Jun 18, 2023
d0f7e22
attempting to fix failing quality tests
adit299 Jun 18, 2023
05920b0
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Jul 17, 2023
f190cb1
Merge branch 'main' of https://github.com/adit299/transformers into A…
adit299 Aug 6, 2023
4f02f9e
resolving test_resize_token_embeddings
adit299 Aug 28, 2023
468a9bf
override dummy_inputs and input_signature methods
adit299 Sep 4, 2023
66fb2b5
removing uneeded file
adit299 Sep 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/whisper.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] TFWhisperForConditionalGeneration
- call

## TFWhisperForAudioClassification

[[autodoc]] TFWhisperForAudioClassification
- call

## FlaxWhisperModel

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3705,6 +3705,7 @@
_import_structure["models.whisper"].extend(
[
"TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWhisperForAudioClassification",
"TFWhisperForConditionalGeneration",
"TFWhisperModel",
"TFWhisperPreTrainedModel",
Expand Down Expand Up @@ -7124,6 +7125,7 @@
)
from .models.whisper import (
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWhisperForAudioClassification,
TFWhisperForConditionalGeneration,
TFWhisperModel,
TFWhisperPreTrainedModel,
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,13 @@
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
]
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])

TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("wav2vec2", "TFWav2Vec2ForSequenceClassification"),
("whisper", "TFWhisperForAudioClassification"),
]
)

TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
Expand Down Expand Up @@ -461,6 +467,10 @@
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"TFWhisperForConditionalGeneration",
"TFWhisperModel",
"TFWhisperPreTrainedModel",
"TFWhisperForAudioClassification",
]

try:
Expand Down Expand Up @@ -115,6 +116,7 @@
else:
from .modeling_tf_whisper import (
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWhisperForAudioClassification,
TFWhisperForConditionalGeneration,
TFWhisperModel,
TFWhisperPreTrainedModel,
Expand Down
100 changes: 99 additions & 1 deletion src/transformers/models/whisper/modeling_tf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
TFSequenceClassifierOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
Expand Down Expand Up @@ -133,7 +134,7 @@ def build(self, input_shape):

def call(self, input_ids, past_key_values_length=0):
past_key_values_length = tf.cast(past_key_values_length, tf.int32)
gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
gather_indices = tf.range(tf.shape(input_ids)[-1], delta=1) + past_key_values_length
return tf.gather(self.weight, gather_indices)


Expand Down Expand Up @@ -627,6 +628,12 @@ def __init__(self, config: WhisperConfig, **kwargs):
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")

self.dropout = tf.keras.layers.Dropout(config.dropout)

def get_input_embeddings(self):
return self.conv1

def set_input_embeddings(self, value):
self.conv1 = value

@unpack_inputs
def call(
Expand Down Expand Up @@ -1599,3 +1606,94 @@ def prepare_inputs_for_generation(
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
}


class TFWhisperForAudioClassification(TFWhisperPreTrainedModel):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM from a functionality point of view! Will leave it to one of the TF maintainers to assist you with the TF-specific semantics, but what you've got here looks complete to me (no need for the freeze encoder method since this is only for fine-tuning, which we don't have in TF)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, great! I still see that there are some tests failing from within TFModelTesterMixin, so I can try to resolve those.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

def __init__(self, config):
super().__init__(config)

self.encoder = TFWhisperEncoder(config)
num_layers = config.num_hidden_layers + 1
if config.use_weighted_layer_sum:
self.layer_weights = tf.Variable(tf.ones(shape=(num_layers,)) / num_layers)
self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, input_shape=(config.hidden_size,))
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, input_shape=(config.classifier_proj_size,), activation=None
)

def get_input_embeddings(self):
return self.encoder.get_input_embeddings()

def set_input_embeddings(self, value):
self.encoder.set_input_embeddings(value)

@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.

Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
return {
self.main_input_name: tf.random.uniform(
[1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
),
}

@property
def input_signature(self):
return {
"input_features": tf.TensorSpec((None, self.config.num_mel_bins, None), tf.float32, name="input_features"),
}

@unpack_inputs
def call(
self,
input_features: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[tf.Tensor]]] = None,
labels: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = tf.stack(encoder_outputs, axis=1)
norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = encoder_outputs[0]

hidden_states = self.projector(hidden_states)
pooled_output = tf.reduce_mean(hidden_states, axis=1)

logits = self.classifier(pooled_output)

loss = None

if labels is not None:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))

if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output

return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,13 @@ def __init__(self, *args, **kwargs):
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFWhisperForAudioClassification(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]

Expand Down
Loading