# [ci-daily] Fix pipeline tests #21257
Changes from all commits: e256e3c, bb191b9, 4c70ec0, 7c47fcd
```diff
@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
```
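For context on the hunk above: `chunk_iter` slides a window of `chunk_len` samples over the audio, advancing by `step = chunk_len - stride_left - stride_right`, so neighbouring chunks overlap by the stride amounts. A toy sketch of that arithmetic (all numbers invented, not pipeline code):

```python
# Toy illustration of chunk_iter's windowing: each chunk is chunk_len long,
# chunks advance by step, so consecutive chunks overlap by
# stride_left + stride_right samples.
inputs = list(range(10))
chunk_len, stride_left, stride_right = 4, 1, 1
step = chunk_len - stride_left - stride_right  # 2

for i in range(0, len(inputs), step):
    print(i, inputs[i : i + chunk_len])
# 0 [0, 1, 2, 3]
# 2 [2, 3, 4, 5]
# 4 [4, 5, 6, 7]
# 6 [6, 7, 8, 9]
# 8 [8, 9]
```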
```diff
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
         chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if ratio != 1:
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        if processed_len != chunk.shape[-1] and rescale:
+            ratio = processed_len / chunk_len
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
```
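The new `rescale` flag replaces the precomputed `ratio` argument: the ratio is now derived per chunk from the length of the processed features. A simplified sketch of the rescaling idea (`rescale_stride_sketch` is a re-implementation for illustration, not the real helper, which may round differently; the sample counts are invented):

```python
# Map a (chunk_len, stride_left, stride_right) tuple from raw-sample space
# into feature-frame space using the measured length ratio.
def rescale_stride_sketch(strides, ratio):
    return [
        (int(round(n * ratio)), int(round(left * ratio)), int(round(right * ratio)))
        for n, left, right in strides
    ]

chunk_len, processed_len = 16000, 3000  # e.g. raw samples -> mel frames
stride = (chunk_len, 4000, 2000)
if processed_len != chunk_len:          # mirrors the new per-chunk check
    ratio = processed_len / chunk_len
    stride = rescale_stride_sketch([stride], ratio)[0]
print(stride)  # (3000, 750, 375)
```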
```diff
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
         sequence = sequence[begin_idx:]
 
         timestamp_tokens = sequence >= timestamp_begin
-        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
-        last_timestamp = np.where(timestamp_tokens)[0][-1]
-        consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
-        if seq_idx != 0:
+        if seq_idx != 0 and sum(timestamp_tokens) > 0:
+            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+            last_timestamp = np.where(timestamp_tokens)[0][-1]
+            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
             time -= stride_left + stride_right
             offset = int((time / feature_extractor.sampling_rate) / time_precision)
             overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
```

**Comment on lines -104 to +113** (Collaborator, Author):

> This just makes sure that if the model outputs no timestamps, we don't throw an error.
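A runnable illustration of why the guard matters: when the sequence contains no timestamp tokens, `np.where(timestamp_tokens)[0][-1]` indexes an empty array and raises `IndexError`; the added `sum(timestamp_tokens) > 0` check skips that path. The token ids below are invented, and `timestamp_begin = 50364` is an assumed Whisper timestamp offset; only the `>= timestamp_begin` comparison mirrors the code above:

```python
import numpy as np

timestamp_begin = 50364                       # assumed Whisper timestamp offset
sequence = np.array([50258, 440, 871, 991])   # made-up ids, no timestamps

timestamp_tokens = sequence >= timestamp_begin
if sum(timestamp_tokens) > 0:  # the new guard
    consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
    last_timestamp = np.where(timestamp_tokens)[0][-1]  # IndexError on an empty match
    consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
else:
    print("no timestamps emitted; skip instead of crashing")
```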
```diff
@@ -400,13 +406,12 @@ def _sanitize_parameters(
                     " only 1 version"
                 )
             forward_params["generate_kwargs"].update(generate_kwargs)
-            if return_timestamps is not None:
-                forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
 
         postprocess_params = {}
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if self.model.config.model_type == "whisper":
             # Whisper is highly specific, if we want timestamps, we need to
```
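Stripped of the pipeline plumbing, the routing change amounts to the following (a minimal sketch, not the full `_sanitize_parameters`): `return_timestamps` now lands in `forward_params` itself, so `_forward` receives it as an explicit argument, while it still reaches `postprocess` as before:

```python
def sanitize_sketch(return_timestamps=None, decoder_kwargs=None):
    forward_params, postprocess_params = {}, {}
    if decoder_kwargs is not None:
        postprocess_params["decoder_kwargs"] = decoder_kwargs
    if return_timestamps is not None:
        # top-level forward param instead of generate_kwargs["return_timestamps"]
        forward_params["return_timestamps"] = return_timestamps
        postprocess_params["return_timestamps"] = return_timestamps
    return forward_params, postprocess_params

print(sanitize_sketch(return_timestamps=True))
# ({'return_timestamps': True}, {'return_timestamps': True})
```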
```diff
@@ -502,9 +507,10 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
             if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
+            rescale = self.type != "seq2seq_whisper"
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
             ):
                 yield item
         else:
```
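User-facing behaviour does not change: chunked transcription is requested the same way, and per the hunk above, a Whisper checkpoint sets `self.type` to `"seq2seq_whisper"`, so `rescale` is `False` and strides stay in the raw-sample domain. A hedged usage sketch (the model id and file path are illustrative choices, not mandated by this PR):

```python
from transformers import pipeline

# Long-form transcription with overlapping chunks; extra kwargs like
# chunk_length_s are forwarded to the pipeline's _sanitize_parameters.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
)
# result = asr("audio.wav", return_timestamps=True)
```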
```diff
@@ -520,12 +526,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
             processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}
 
-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
 
         is_last = model_inputs.pop("is_last")
-        return_timestamps = generate_kwargs.pop("return_timestamps", False)
 
         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
```

**Comment** (Collaborator, Author):

> Adding this argument prevents the […]
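One concrete hazard of the old `generate_kwargs.pop("return_timestamps", False)` pattern, offered as a plausible reading of the truncated comment above rather than a confirmed rationale: popping mutates the shared dict, so the flag is consumed on the first call and lost for every later chunk. A toy illustration, not the pipeline code:

```python
generate_kwargs = {"return_timestamps": True}

def forward_old_style(kwargs):
    # old pattern: consume the flag out of the shared generate_kwargs
    return kwargs.pop("return_timestamps", False)

print(forward_old_style(generate_kwargs))  # True  -- first chunk sees the flag
print(forward_old_style(generate_kwargs))  # False -- popped, later chunks lose it
```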
```diff
@@ -635,9 +640,9 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu
             # Simply cast from pyctcdecode format to wav2vec2 format to leverage
             # pre-existing code later
             chunk_offset = beams[0][2]
-            word_offsets = []
+            offsets = []
             for word, (start_offset, end_offset) in chunk_offset:
-                word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+                offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
         else:
             skip_special_tokens = self.type != "ctc"
             text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
```

**Comment on lines -638 to +645** (Collaborator, Author):

> Normalized the name of this variable.
**Comment:**

> This was missing! Fixes the LM tests.
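For reference, the loop touched by the rename casts pyctcdecode beam output into wav2vec2-style offset dicts. A self-contained sketch with invented beam data (only the tuple shape mirrors what the code above consumes):

```python
# pyctcdecode beams carry (word, (start_offset, end_offset)) pairs; the loop
# converts them into the dict format the rest of postprocess expects.
chunk_offset = [("hello", (0, 12)), ("world", (14, 25))]  # made-up offsets

offsets = []
for word, (start_offset, end_offset) in chunk_offset:
    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})

print(offsets)
# [{'word': 'hello', 'start_offset': 0, 'end_offset': 12},
#  {'word': 'world', 'start_offset': 14, 'end_offset': 25}]
```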