-
Notifications
You must be signed in to change notification settings - Fork 29
Support Whisper #45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Whisper #45
Changes from 10 commits
c5a899a
b2401ff
6e050d7
62800ff
a108a6f
5106c69
42a9cd2
de1d7f8
ba9644f
462acc0
d71ae4a
e40b0d1
f32e7ac
d6c1718
33d6274
2a3df8a
09fb277
edbb88c
e4b053b
f396043
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,8 @@ | |
| from transformers.generation.configuration_utils import GenerationConfig | ||
|
|
||
| from optimum.utils.import_utils import is_transformers_version | ||
| from transformers import PreTrainedModel, StaticCache, WhisperForConditionalGeneration | ||
| from transformers.generation.configuration_utils import GenerationConfig | ||
|
|
||
| from .utils import save_config_to_constant_methods | ||
|
|
||
|
|
@@ -153,7 +155,7 @@ def __init__(self, encoder_model): | |
| self.config = encoder_model.config | ||
|
|
||
| def forward(self, input_ids): | ||
| return self.encoder(input_ids=input_ids).last_hidden_state | ||
| return self.encoder(input_ids).last_hidden_state | ||
|
|
||
|
|
||
| class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): | ||
|
|
@@ -168,7 +170,10 @@ def __init__(self, model, max_static_cache_length, batch_size): | |
|
|
||
| # Get the decoder component | ||
| self.decoder = model.get_decoder() | ||
| self.lm_head = model.lm_head | ||
| if isinstance(model, WhisperForConditionalGeneration): | ||
| self.proj_out = model.proj_out | ||
| else: | ||
| self.proj_out = model.lm_head | ||
chmjkb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.config = model.config | ||
|
|
||
| # Initialize static cache | ||
|
|
@@ -195,10 +200,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): | |
| cache_position=cache_position, | ||
| ) | ||
|
|
||
| # Apply language model head | ||
| lm_logits = self.lm_head(outputs[0]) | ||
|
|
||
| return lm_logits | ||
| # Apply linear projection (lm head) to obtain logits | ||
| logits = self.proj_out(outputs[0]) | ||
| return logits | ||
|
|
||
|
|
||
| class Seq2SeqLMExportableModule(torch.nn.Module): | ||
|
|
@@ -240,14 +244,20 @@ def _export_encoder(self, encoder_input_ids): | |
| wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval() | ||
|
|
||
| # Define dynamic sequence length for encoder | ||
| seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) | ||
| if isinstance(self.full_model, WhisperForConditionalGeneration): | ||
| assert encoder_input_ids.shape == torch.Size( | ||
| [1, 80, 3000] | ||
| ), f"Whisper only accepts a log-mel spectrogram of shape [1, 80, 3000], passed shape: {encoder_input_ids.shape}" | ||
|
||
| dynamic_shapes = None | ||
| else: | ||
chmjkb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length) | ||
| dynamic_shapes = {"input_ids": {1: seq_len_dim}} | ||
|
|
||
| # Export the encoder | ||
| with torch.no_grad(): | ||
| exported_encoder = torch.export.export( | ||
| wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True | ||
| wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True | ||
| ) | ||
|
|
||
| return exported_encoder | ||
|
|
||
| def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position): | ||
|
|
@@ -261,19 +271,23 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi | |
| .eval() | ||
| ) | ||
|
|
||
| # Define dynamic dimension for encoder output sequence length | ||
| encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) | ||
| if isinstance(self.full_model, WhisperForConditionalGeneration): | ||
| dynamic_shapes = None | ||
| else: | ||
chmjkb marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Define dynamic dimension for encoder output sequence length | ||
| encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length) | ||
| dynamic_shapes = { | ||
| "decoder_input_ids": None, | ||
| "encoder_hidden_states": {1: encoder_seq_len_dim}, | ||
| "cache_position": None, | ||
| } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tugsbayasgalan @pianpwk Can we use
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep Dim.AuTO would be perfect here. Doing so, you don't need the if/else branching.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried changing the code to the following: Unfortunately when I do that, the export fails with the following error: However doing: |
||
|
|
||
| # Export the decoder | ||
| with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): | ||
| exported_decoder = torch.export.export( | ||
| wrapped_decoder, | ||
| (decoder_input_ids, encoder_hidden_states, cache_position), | ||
| dynamic_shapes={ | ||
| "decoder_input_ids": None, | ||
| "encoder_hidden_states": {1: encoder_seq_len_dim}, | ||
| "cache_position": None, | ||
| }, | ||
| dynamic_shapes=dynamic_shapes, | ||
| strict=True, | ||
| ) | ||
|
|
||
|
|
@@ -286,21 +300,26 @@ def export( | |
| encoder_hidden_states=None, | ||
| cache_position=None, | ||
| ) -> Dict[str, ExportedProgram]: | ||
| example_encoder_input_ids = ( | ||
| encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long) | ||
| ) | ||
| if encoder_input_ids is None: | ||
| if isinstance(self.full_model, WhisperForConditionalGeneration): | ||
| example_encoder_input_ids = torch.rand((1, 80, 3000)) | ||
| else: | ||
| example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long) | ||
| else: | ||
| example_encoder_input_ids = encoder_input_ids | ||
|
|
||
| self.exported_encoder = self._export_encoder(example_encoder_input_ids) | ||
|
|
||
| if not encoder_hidden_states: | ||
| example_encoder_hidden_states = self.exported_encoder.module()(example_encoder_input_ids) | ||
| else: | ||
| example_encoder_hidden_states = encoder_hidden_states | ||
|
|
||
| example_decoder_input_ids = ( | ||
| decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long) | ||
| ) # Start token | ||
| example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) | ||
| example_encoder_hidden_states = ( | ||
| encoder_hidden_states | ||
| if encoder_hidden_states is not None | ||
| else torch.zeros( | ||
| (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32 | ||
| ) | ||
| ) | ||
| self.exported_encoder = self._export_encoder(example_encoder_input_ids) | ||
| example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) | ||
|
|
||
| self.exported_decoder = self._export_decoder( | ||
| example_decoder_input_ids, example_encoder_hidden_states, example_cache_position | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.