fix issue with logit processor during beam search in Flax #29636
Conversation
Hi @giganttheo 👋 Thank you for opening the PR! The fix looks reasonable to me. However, if the fix is indeed correct, I wonder how our code could be running correctly before 🤔 I have three small requests:
I think that not many people have encountered this issue before, since most Flax logits processors do not really use the content of the input ids, only the current length, so they are unaffected by this bug (a minimal sketch illustrating the difference is included after the outputs below).
About your requests: here is a short reproduction with t5-small:
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype="bfloat16")
input_text = "translate English to French: hello how are you? hello how are you? hello how are you? hello how are you? hello how are you?"
input_ids = tokenizer(input_text, return_tensors="np").input_ids
decoder_start_token_id=model.config.decoder_start_token_id
outputs = model.generate(input_ids=input_ids, num_beams=2, decoder_start_token_id=decoder_start_token_id, no_repeat_ngram_size=2)
outputs.sequences, tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
Without the change, it gives:
(Array([[    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58,
         21845,     6,  1670,     3,  6738,    18,  3249,    58,     1,
             0,     0]], dtype=int32),
 ['Bonjour, comment êtes-vous? Bonjour, comment êtes-vous?'])
For instance, the 2-gram (21845, 6) is repeated, so the no_repeat_ngram_size=2 constraint is not respected.
With the change, it gives:
(Array([[    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58,
         21845,     3,    15,    17,  1670,    58,     1,     0,     0,
             0,     0]], dtype=int32),
 ['Bonjour, comment êtes-vous? Bonjour et comment?'])
There is no 2-gram repetition. For reference, with torch and 2-gram blocking, the model output is:
tensor([    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58, 21845,
            3,    15,    17,  1670,   327,    58,     1])
Bonjour, comment êtes-vous? Bonjour et comment vous?
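For completeness, the torch reference above can presumably be reproduced with a script along these lines (my addition, not part of the original comment; it mirrors the checkpoint and generation arguments of the Flax snippet):
# Hypothetical PyTorch counterpart of the Flax reproduction above (a sketch, not taken from the PR).
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_text = (
    "translate English to French: hello how are you? hello how are you? "
    "hello how are you? hello how are you? hello how are you?"
)
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# Beam search with 2-gram blocking, as in the Flax example.
output_ids = model.generate(input_ids, num_beams=2, no_repeat_ngram_size=2)
print(output_ids[0])
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))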
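As a quick sanity check on the "no 2-gram repetition" claim, a tiny helper (hypothetical, not part of the PR) over the token ids shown above, with trailing padding stripped:
def has_repeated_bigram(tokens):
    # True if any pair of consecutive tokens occurs more than once
    bigrams = list(zip(tokens, tokens[1:]))
    return len(bigrams) != len(set(bigrams))
# Output before the fix: the 2-gram (21845, 6) occurs twice -> True
print(has_repeated_bigram([0, 21845, 6, 1670, 3, 6738, 18, 3249, 58, 21845, 6, 1670, 3, 6738, 18, 3249, 58, 1]))
# Output with the fix: no repeated 2-gram -> False
print(has_repeated_bigram([0, 21845, 6, 1670, 3, 6738, 18, 3249, 58, 21845, 3, 15, 17, 1670, 58, 1]))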
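Regarding why the bug went unnoticed before: Flax logits processors are called with (input_ids, scores, cur_len), but most of them only look at cur_len, whereas a processor such as n-gram blocking has to read the generated tokens themselves. A minimal sketch of the two kinds (hypothetical processors, not the library implementations):
import jax.numpy as jnp
def length_only_processor(input_ids, scores, cur_len, min_length=5, eos_token_id=1):
    # Depends only on cur_len: blocks EOS until min_length tokens have been generated.
    # If beam search passes the wrong input_ids, this kind of processor is unaffected.
    penalty = jnp.where(cur_len < min_length, -jnp.inf, 0.0)
    return scores.at[:, eos_token_id].add(penalty)
def content_dependent_processor(input_ids, scores, cur_len, banned_token_id=6):
    # Reads the generated tokens: bans banned_token_id if it was the previous token.
    # This kind of processor silently misbehaves when given the wrong sequences.
    last_token = input_ids[:, cur_len - 1]
    ban = (last_token == banned_token_id)[:, None] & (jnp.arange(scores.shape[-1]) == banned_token_id)[None, :]
    return jnp.where(ban, -jnp.inf, scores)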
============================= test session starts ==============================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 52 items
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_lm_forward PASSED [ 1%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_lm_uneven_forward PASSED [ 3%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_question_answering_forward PASSED [ 5%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_sequence_classification_forward PASSED [ 7%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_shift_tokens_right PASSED [ 9%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 11%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate <- tests/generation/test_flax_utils.py PASSED [ 13%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 15%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 17%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_num_return_sequences <- tests/generation/test_flax_utils.py PASSED [ 19%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 21%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 23%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_cnn_summarization_same_as_fairseq PASSED [ 25%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_decode PASSED [ 26%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 28%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_encode PASSED [ 30%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 32%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 34%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 36%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 38%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 40%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 42%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 44%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate <- tests/generation/test_flax_utils.py PASSED [ 46%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 48%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 50%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_pt_fx <- tests/generation/test_flax_utils.py PASSED [ 51%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 53%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 55%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 57%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 59%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_from_pretrained PASSED [ 61%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 65%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 69%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate <- tests/generation/test_flax_utils.py PASSED [ 71%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 73%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 75%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_bf16_to_base_pt <- tests/test_modeling_flax_common.py PASSED [ 76%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_from_base <- tests/test_modeling_flax_common.py PASSED [ 78%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_from_base_pt <- tests/test_modeling_flax_common.py PASSED [ 80%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 82%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 84%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_to_base <- tests/test_modeling_flax_common.py PASSED [ 86%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_to_base_pt <- tests/test_modeling_flax_common.py PASSED [ 88%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_summarization_fast PASSED [ 90%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 92%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 94%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 96%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_use_cache_forward PASSED [ 98%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_use_cache_forward_with_attn_mask PASSED [100%]
=============================== warnings summary ===============================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
tests/models/bart/test_modeling_flax_bart.py: 371 warnings
/home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_flax_to_pt
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 52 passed, 373 warnings in 166.31s (0:02:46) =================
======================================== test session starts ========================================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 86 items
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 1%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate <- tests/generation/test_flax_utils.py PASSED [ 2%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 3%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 4%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_num_return_sequences <- tests/generation/test_flax_utils.py PASSED [ 5%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 6%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 8%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_config PASSED [ 9%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_decode PASSED [ 10%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 11%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_encode PASSED [ 12%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 13%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 15%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 16%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 17%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 18%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 19%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 20%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate <- tests/generation/test_flax_utils.py PASSED [ 22%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 23%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 24%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_pt_fx <- tests/generation/test_flax_utils.py PASSED [ 25%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 26%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 27%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 29%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 30%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model PASSED [ 31%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 32%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 33%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_v1_1 PASSED [ 34%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 36%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 37%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate <- tests/generation/test_flax_utils.py PASSED [ 38%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 39%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 40%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_bf16_to_base_pt PASSED [ 41%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_from_base PASSED [ 43%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_from_base_pt PASSED [ 44%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 45%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 46%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_to_base PASSED [ 47%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_to_base_pt PASSED [ 48%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_shift_right PASSED [ 50%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 51%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 52%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 53%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_use_cache_forward_with_attn_mask PASSED [ 54%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 55%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 56%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 58%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_config PASSED [ 59%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 60%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_encode PASSED [ 61%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 62%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 65%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 66%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 68%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 69%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 70%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 72%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 73%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 74%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model PASSED [ 75%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 76%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 77%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_v1_1 PASSED [ 79%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 80%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 81%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_bf16_to_base_pt PASSED [ 82%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_from_base PASSED [ 83%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_from_base_pt PASSED [ 84%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 86%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 87%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_to_base PASSED [ 88%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_to_base_pt PASSED [ 89%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 90%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 91%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 93%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_byt5_integration_test SKIPPED [ 94%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_generation SKIPPED [ 95%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_generation_bfloat16 SKIPPED [ 96%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_integration_test SKIPPED [ 97%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_v1_1_integration_test SKIPPED [ 98%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_summarization SKIPPED [100%]
========================================= warnings summary ==========================================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
tests/models/t5/test_modeling_flax_t5.py: 113 warnings
/home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_flax_to_pt
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================== 80 passed, 6 skipped, 115 warnings in 65.67s (0:01:05) =======================
======================================== test session starts ========================================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 74 items
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 1%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 2%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 4%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_config PASSED [ 5%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 6%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions PASSED [ 8%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 9%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 10%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_forward_signature PASSED [ 12%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 13%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 14%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 16%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 17%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 18%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 20%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_jit_compilation PASSED [ 21%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 22%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 24%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 25%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 27%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 28%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_bf16_to_base_pt PASSED [ 29%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_from_base PASSED [ 31%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_from_base_pt PASSED [ 32%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 33%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 35%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_to_base PASSED [ 36%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_to_base_pt PASSED [ 37%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 39%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 40%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 41%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation FAILED [ 43%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation PASSED [ 44%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual FAILED [ 45%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech FAILED [ 47%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech FAILED [ 48%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_batched_generation PASSED [ 50%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation FAILED [ 51%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_generation PASSED [ 52%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_logits_librispeech PASSED [ 54%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation FAILED [ 55%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 56%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 58%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 59%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_config PASSED [ 60%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 62%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 64%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_forward_signature PASSED [ 66%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 68%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 70%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 71%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 72%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 74%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_inputs_embeds PASSED [ 75%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_jit_compilation PASSED [ 77%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 78%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_common_attributes PASSED [ 79%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 81%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 82%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 83%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 85%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_resize_tokens_embeddings PASSED [ 86%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_bf16_to_base_pt PASSED [ 87%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_from_base PASSED [ 89%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_from_base_pt PASSED [ 90%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 91%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 93%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_to_base PASSED [ 94%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_to_base_pt PASSED [ 95%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 97%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 98%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [100%]
============================================= FAILURES ==============================================
___________________ FlaxWhisperModelIntegrationTest.test_large_batched_generation ___________________
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_batched_generation>
def test_large_batched_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
generated_ids = model.generate(input_features, max_length=20).sequences
# fmt: off
EXPECTED_LOGITS = np.array(
[
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
]
)
# fmt: on
> self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
E AssertionError: False is not true
tests/models/whisper/test_modeling_flax_whisper.py:613: AssertionError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
________________ FlaxWhisperModelIntegrationTest.test_large_generation_multilingual _________________
self = <fsspec.implementations.http.HTTPFileSystem object at 0x7f4b0227d8e0>
url = 'https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz'
kwargs = {}, info = {}, session = <aiohttp.client.ClientSession object at 0x7f4b028e0250>
policy = 'get'
async def _info(self, url, **kwargs):
"""Get info of URL
Tries to access location via HEAD, and then GET methods, but does
not fetch the data.
It is possible that the server does not supply any size information, in
which case size will be given as None (and certain operations on the
corresponding file will not work).
"""
info = {}
session = await self.set_session()
for policy in ["head", "get"]:
try:
info.update(
> await _file_info(
self.encode_url(url),
size_policy=policy,
session=session,
**self.kwargs,
**kwargs,
)
)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:419:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:832: in _file_info
r.raise_for_status()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ClientResponse(https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-202...e': 'application/xml', 'Transfer-Encoding': 'chunked', 'Date': 'Mon, 18 Mar 2024 18:09:08 GMT', 'Server': 'AmazonS3')>
def raise_for_status(self) -> None:
if not self.ok:
# reason should always be not None for a started response
assert self.reason is not None
self.release()
> raise ClientResponseError(
self.request_info,
self.history,
status=self.status,
message=self.reason,
headers=self.headers,
)
E aiohttp.client_exceptions.ClientResponseError: 403, message='Forbidden', url=URL('https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz')
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/aiohttp/client_reqrep.py:1060: ClientResponseError
The above exception was the direct cause of the following exception:
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_generation_multilingual>
def test_large_generation_multilingual(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
ds = load_dataset("common_voice", "ja", split="test", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
> input_speech = next(iter(ds))["audio"]["array"]
tests/models/whisper/test_modeling_flax_whisper.py:566:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/iterable_dataset.py:1388: in __iter__
for key, example in ex_iterable:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/iterable_dataset.py:234: in __iter__
yield from self.generate_examples_fn(**self.kwargs)
../../../.cache/huggingface/modules/datasets_modules/datasets/common_voice/220833898d6a60c50f621126e51fb22eb2dfe5244392c70dccd8e6e2f055f4bf/common_voice.py:774: in _generate_examples
for path, f in archive_iterator:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:869: in __iter__
yield from self.generator(*self.args, **self.kwargs)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:922: in _iter_from_urlpath
with xopen(urlpath, "rb", download_config=download_config, block_size=0) as f:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:512: in xopen
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/core.py:135: in open
return self.__enter__()
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/core.py:103: in __enter__
f = self.fs.open(self.path, mode=mode)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/spec.py:1293: in open
f = self._open(
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:358: in _open
size = size or self.info(path, **kwargs)["size"]
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:118: in wrapper
return sync(self.loop, func, *args, **kwargs)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:103: in sync
raise return_result
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:56: in _runner
result[0] = await coro
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <fsspec.implementations.http.HTTPFileSystem object at 0x7f4b0227d8e0>
url = 'https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz'
kwargs = {}, info = {}, session = <aiohttp.client.ClientSession object at 0x7f4b028e0250>
policy = 'get'
async def _info(self, url, **kwargs):
"""Get info of URL
Tries to access location via HEAD, and then GET methods, but does
not fetch the data.
It is possible that the server does not supply any size information, in
which case size will be given as None (and certain operations on the
corresponding file will not work).
"""
info = {}
session = await self.set_session()
for policy in ["head", "get"]:
try:
info.update(
await _file_info(
self.encode_url(url),
size_policy=policy,
session=session,
**self.kwargs,
**kwargs,
)
)
if info.get("size") is not None:
break
except Exception as exc:
if policy == "get":
# If get failed, then raise a FileNotFoundError
> raise FileNotFoundError(url) from exc
E FileNotFoundError: https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:432: FileNotFoundError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
___________________ FlaxWhisperModelIntegrationTest.test_large_logits_librispeech ___________________
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_logits_librispeech>
def test_large_logits_librispeech(self):
model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
input_speech = self._load_datasamples(1)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processed_inputs = processor(
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
)
input_features = processed_inputs.input_features
decoder_input_ids = processed_inputs.labels
logits = model(
input_features,
decoder_input_ids=decoder_input_ids,
output_hidden_states=False,
output_attentions=False,
return_dict=False,
)
> logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
E KeyError: 'model'
tests/models/whisper/test_modeling_flax_whisper.py:492: KeyError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
_________________ FlaxWhisperModelIntegrationTest.test_small_en_logits_librispeech __________________
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_small_en_logits_librispeech>
def test_small_en_logits_librispeech(self):
model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="np").input_features
> logits = model(
input_features,
decoder_input_ids=np.array([model.config.decoder_start_token_id]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
)
tests/models/whisper/test_modeling_flax_whisper.py:451:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <transformers.models.whisper.modeling_flax_whisper.FlaxWhisperModel object at 0x7f4b2e3ee0a0>
input_features = array([[[ 1.1933082e-01, -9.4576120e-02, -1.0977852e-01, ...,
-8.0602670e-01, -8.0602670e-01, -8.0602670e-01]...70e-01, -8.0602670e-01, -8.0602670e-01, ...,
-8.0602670e-01, -8.0602670e-01, -8.0602670e-01]]], dtype=float32)
decoder_input_ids = array([50257]), attention_mask = None, decoder_attention_mask = None
position_ids = None, decoder_position_ids = None, output_attentions = False
output_hidden_states = False, return_dict = False, train = False, params = None, dropout_rng = None
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
decoder_input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare decoder inputs
if decoder_position_ids is None:
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
else:
> batch_size, sequence_length = decoder_input_ids.shape
E ValueError: not enough values to unpack (expected 2, got 1)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/models/whisper/modeling_flax_whisper.py:1161: ValueError
--------------------------------------- Captured stderr call ----------------------------------------
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
______________________ FlaxWhisperModelIntegrationTest.test_tiny_en_generation ______________________
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_tiny_en_generation>
def test_tiny_en_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
model.config.decoder_start_token_id = 50257
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(
raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
).input_features
generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
" classes and we are glad to"
)
> self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
E AssertionError: '<|st[14 chars]t|><|notimestamps|> Mr. Quilter is the apostle[84 chars]xt|>' != '<|st[14 chars]t|><|en|><|transcribe|><|notimestamps|> Mr. Qu[57 chars]d to'
E - <|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes,<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
E + <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to
tests/models/whisper/test_modeling_flax_whisper.py:523: AssertionError
__________________ FlaxWhisperModelIntegrationTest.test_tiny_timestamp_generation ___________________
self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_tiny_timestamp_generation>
@slow
def test_tiny_timestamp_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features
generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
generated_ids = generate_fn(input_features)
EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257]) # fmt: skip
> self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))
tests/models/whisper/test_modeling_flax_whisper.py:675:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/numpy/core/numeric.py:2241: in allclose
res = all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = FlaxGreedySearchOutput(sequences=Array([[50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391,
307,...257, 50257, 50257, 50257, 50257, 50257, 50257,
50257, 50257, 50257, 50257, 50257, 50257, 50257]], dtype=int32))
b = array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391,
307, 264, 50244, 295, 264, 2808,... 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281,
934, 439, 11, 293, 51836, 51836, 50257])
rtol = 1e-05, atol = 1e-08, equal_nan = False
@array_function_dispatch(_isclose_dispatcher)
def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
"""
Returns a boolean array where two arrays are element-wise equal within a
tolerance.
The tolerance values are positive, typically very small numbers. The
relative difference (`rtol` * abs(`b`)) and the absolute difference
`atol` are added together to compare against the absolute difference
between `a` and `b`.
.. warning:: The default `atol` is not appropriate for comparing numbers
that are much smaller than one (see Notes).
Parameters
----------
a, b : array_like
Input arrays to compare.
rtol : float
The relative tolerance parameter (see Notes).
atol : float
The absolute tolerance parameter (see Notes).
equal_nan : bool
Whether to compare NaN's as equal. If True, NaN's in `a` will be
considered equal to NaN's in `b` in the output array.
Returns
-------
y : array_like
Returns a boolean array of where `a` and `b` are equal within the
given tolerance. If both `a` and `b` are scalars, returns a single
boolean value.
See Also
--------
allclose
math.isclose
Notes
-----
.. versionadded:: 1.7.0
For finite values, isclose uses the following equation to test whether
two floating point values are equivalent.
absolute(`a` - `b`) <= (`atol` + `rtol` * absolute(`b`))
Unlike the built-in `math.isclose`, the above equation is not symmetric
in `a` and `b` -- it assumes `b` is the reference value -- so that
`isclose(a, b)` might be different from `isclose(b, a)`. Furthermore,
the default value of atol is not zero, and is used to determine what
small values should be considered close to zero. The default value is
appropriate for expected values of order unity: if the expected values
are significantly smaller than one, it can result in false positives.
`atol` should be carefully selected for the use case at hand. A zero value
for `atol` will result in `False` if either `a` or `b` is zero.
`isclose` is not defined for non-numeric data types.
`bool` is considered a numeric data-type for this purpose.
Examples
--------
>>> np.isclose([1e10,1e-7], [1.00001e10,1e-8])
array([ True, False])
>>> np.isclose([1e10,1e-8], [1.00001e10,1e-9])
array([ True, True])
>>> np.isclose([1e10,1e-8], [1.0001e10,1e-9])
array([False, True])
>>> np.isclose([1.0, np.nan], [1.0, np.nan])
array([ True, False])
>>> np.isclose([1.0, np.nan], [1.0, np.nan], equal_nan=True)
array([ True, True])
>>> np.isclose([1e-8, 1e-7], [0.0, 0.0])
array([ True, False])
>>> np.isclose([1e-100, 1e-7], [0.0, 0.0], atol=0.0)
array([False, False])
>>> np.isclose([1e-10, 1e-10], [1e-20, 0.0])
array([ True, True])
>>> np.isclose([1e-10, 1e-10], [1e-20, 0.999999e-10], atol=0.0)
array([False, True])
"""
def within_tol(x, y, atol, rtol):
with errstate(invalid='ignore'), _no_nep50_warning():
return less_equal(abs(x-y), atol + rtol * abs(y))
x = asanyarray(a)
y = asanyarray(b)
# Make sure y is an inexact type to avoid bad behavior on abs(MIN_INT).
# This will cause casting of x later. Also, make sure to allow subclasses
# (e.g., for numpy.ma).
# NOTE: We explicitly allow timedelta, which used to work. This could
# possibly be deprecated. See also gh-18286.
# timedelta works if `atol` is an integer or also a timedelta.
# Although, the default tolerances are unlikely to be useful
if y.dtype.kind != "m":
dt = multiarray.result_type(y, 1.)
y = asanyarray(y, dtype=dt)
> xfin = isfinite(x)
E TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/numpy/core/numeric.py:2348: TypeError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
========================================= warnings summary ==========================================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
tests/models/whisper/test_modeling_flax_whisper.py: 219 warnings
/home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/models/whisper/modeling_flax_whisper.py:72: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_flax_to_pt
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_batched_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/load.py:1461: FutureWarning: The repository for hf-internal-testing/librispeech_asr_dummy contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hf-internal-testing/librispeech_asr_dummy
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
warnings.warn(
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/librosa/core/intervals.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
from pkg_resources import resource_filename
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual
/home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/load.py:1461: FutureWarning: The repository for common_voice contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/common_voice
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
warnings.warn(
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual
/home/gigant/.cache/huggingface/modules/datasets_modules/datasets/common_voice/220833898d6a60c50f621126e51fb22eb2dfe5244392c70dccd8e6e2f055f4bf/common_voice.py:634: FutureWarning:
This version of the Common Voice dataset is deprecated.
You can download the latest one with
>>> load_dataset("mozilla-foundation/common_voice_11_0", "en")
warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================== short test summary info ======================================
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual - FileNotFoundError: https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazon...
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech - KeyError: 'model'
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech - ValueError: not enough values to unpack (expected 2, got 1)
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation - AssertionError: '<|st[14 chars]t|><|notimestamps|> Mr. Quilter is the apostle[84 chars]xt|>' != ...
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation - TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safel...
====================== 6 failed, 68 passed, 235 warnings in 491.73s (0:08:11) =======================
gante
left a comment
@giganttheo thank you for the detailed explanations 💛
The same 6 slow tests are failing on main, so they are not a result of this PR (cc @sanchit-gandhi)
amyeroberts
left a comment
Thanks for fixing!
What does this PR do?
Fixes #29635