Fix slow test_moshika_greedy_unconditional_fp16 #39251
Conversation
run-slow: moshi

This comment contains run-slow, running the specified jobs: models: ['models/moshi']

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
eustlb left a comment
Thanks for the update @manueldeprada, still some work to come to fix moshi...
Is this PR still relevant concerning the changes to cache_utils.py and generation/utils.py? If not, can you resolve the merge conflicts?
     def _prepare_attention_mask_for_generation(
         self,
-        input_ids: torch.LongTensor,
+        inputs_tensor: torch.Tensor,
         generation_config: GenerationConfig,
-        kwargs: dict[str, Any],
+        model_kwargs: dict[str, Any],
     ) -> torch.LongTensor:
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-
-        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
-        if pad_token_id is None:
-            return default_attention_mask
-
-        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
-        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
-            eos_token_id, pad_token_id
-        ).any()
-        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
-
-        attention_mask = (
-            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
-        )
-        return attention_mask
+        return super()._prepare_attention_mask_for_generation(
+            inputs_tensor=inputs_tensor,
+            generation_config=generation_config,
+            model_kwargs={},
+        )
can't we just remove the _prepare_attention_mask_for_generation override here?
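For context, here is a minimal, self-contained sketch of the mask inference that the base GenerationMixin already performs; the function name and signature are illustrative, not the actual library code, but it is essentially what the removed override duplicated:

```python
import torch

def infer_attention_mask(input_ids, pad_token_id=None, eos_token_id=None):
    """Return a 0/1 attention mask inferred from padding, falling back to all ones."""
    # Default: attend to every position.
    default_mask = torch.ones_like(input_ids, dtype=torch.long)
    if pad_token_id is None:
        return default_mask
    # Only infer a mask when the pad token actually appears and is not also the EOS token;
    # otherwise padding cannot be told apart from real tokens.
    has_pad = bool((input_ids == pad_token_id).any())
    pad_is_not_eos = eos_token_id is None or pad_token_id != eos_token_id
    if has_pad and pad_is_not_eos:
        return input_ids.ne(pad_token_id).long()
    return default_mask
```

If the base implementation covers Moshi's needs, the override adds no behavior and dropping it keeps the two code paths from drifting apart.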
     slicing = torch.arange(max_cache_len, device=value_states.device)
     current_seq_len = cache_position[-1] + 1  # Use last position to determine current length
-    to_shift = current_seq_len > max_cache_len
+    to_shift = current_seq_len >= max_cache_len
this will break other models, no?
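To make the concern concrete, here is a toy illustration (not the actual cache_utils.py code) of the boundary step the diff moves: with `>` the sliding-window cache only starts shifting once the sequence has already exceeded the window, while with `>=` it starts shifting on the step that exactly fills it.

```python
MAX_CACHE_LEN = 8  # illustrative window size

def starts_shifting(current_seq_len: int, strict_greater: bool) -> bool:
    # strict_greater=True mirrors the old `>` condition, False mirrors the new `>=` one.
    if strict_greater:
        return current_seq_len > MAX_CACHE_LEN
    return current_seq_len >= MAX_CACHE_LEN

for seq_len in (7, 8, 9):
    print(seq_len, "old(>):", starts_shifting(seq_len, True), "new(>=):", starts_shifting(seq_len, False))
# Only seq_len == 8 differs: the old condition still writes in place, the new one already
# shifts the window, which is the boundary other sliding-window models may rely on.
```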
     # - different models have a different cache name expected by the model (default = "past_key_values")
     # - `max_length`, prepared above, is used to determine the maximum cache length
-    max_cache_length = generation_config.max_length - 1
+    max_cache_length = generation_config.max_length
this will break other models, no?
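A toy sketch of this off-by-one, under the assumption that `max_length` counts prompt plus generated tokens and the pre-allocated cache needs one slot per stored position; this is an illustration of the arithmetic, not the actual generation/utils.py logic. Whether the final slot is ever written depends on when generation stops, which is presumably why other models size the cache to `max_length - 1` and why the reviewer worries about them.

```python
prompt_len, max_new_tokens = 5, 11
max_length = prompt_len + max_new_tokens  # assumed: max_length counts prompt + generated tokens

old_cache_len = max_length - 1  # sizing before this diff: no slot for the final position
new_cache_len = max_length      # sizing the diff proposes: one slot per position

print("positions generated:", max_length)
print("old cache slots:", old_cache_len)
print("new cache slots:", new_cache_len)
```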
Fix #38725
Coming from #38725: previously, e18f233 attempted to fix the default attention mask issue that appeared with #34464, but the slow test tests/models/moshi/test_modeling_moshi.py::MoshiIntegrationTests::test_moshika_greedy_unconditional_fp16 was still failing.
History from git bisect:
- transformers/src/transformers/modeling_utils.py, Line 1413 in 84a6789
- transformers/src/transformers/generation/utils.py, Line 2090 in 36bf1d2
- transformers/src/transformers/cache_utils.py, Line 1740 in 1b22290
Setting cache_implementation="dynamic" makes the test pass, but the sliding window cache should not behave differently. I believe this is due to the depth decoder using a sliding window of 8 by default, but the audio side is still confusing to me.
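For reference, a sketch of how that workaround can be requested at generation time; the checkpoint and prompt below are placeholders used only to show the flag, not the actual slow-test setup.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model, not Moshi
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("hello", return_tensors="pt")
out = model.generate(
    **inputs,
    cache_implementation="dynamic",  # bypass the default pre-allocated/sliding-window cache path
    max_new_tokens=8,
    do_sample=False,                 # greedy, matching the failing test's decoding mode
)
print(tok.decode(out[0]))
```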
This PR is not a fix: the modeling code should be changed to accommodate what I highlight in the diff.
cc @eustlb @ydshieh