
T5 model: the inference results are wrong when batch size > 1 #1847

Closed
0xd8b opened this issue Jun 26, 2024 · 17 comments
Labels: bug (Something isn't working), Investigating

@0xd8b commented Jun 26, 2024

System Info

A100, TensorRT-LLM 0.7.0

Who can help?

@byshiue @sy

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Converted the T5-large model according to the official example, using GPT and BERT plugins with float16 precision. Inference works correctly when batch size is 1.

  2. When batch size > 1, e.g., batch size = 4, we observed that the self-attention results are correct for odd batch indices, but the output of self-attention is all zeros for even batch indices.

  3. We debugged the decoder separately (using the T5 encoder output from HF as the input for the decoder) and found that self-attention in the decoder works correctly. However, in the cross-attention, the results are correct for odd batch indices, but the output is all zeros for even batch indices.

The above phenomenon only occurs when using the BERT and GPT plugins; it does not occur in plain TensorRT mode.
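
For reference, a minimal sketch of the kind of per-batch check described above, assuming the attention output has been dumped to a tensor of shape [batch_size, seq_len, hidden_size] (the tensor name and placeholder data are illustrative, not part of the original report):

```python
import torch

# Debug-dumped attention output, shape [batch_size, seq_len, hidden_size];
# filled with placeholder data here just to make the snippet runnable.
attn_out = torch.zeros(4, 545, 1024, dtype=torch.float16)

# Report which batch indices came back entirely zero.
for b in range(attn_out.shape[0]):
    all_zero = bool(torch.all(attn_out[b] == 0))
    print(f"batch index {b}: {'all zeros' if all_zero else 'non-zero output'}")
```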

Expected behavior

When the batch size is greater than 1, inference with the BERT and GPT plugins in the T5 model should produce correct results for every batch index, consistent with the batch size = 1 outputs.

Actual behavior

When the batch size is greater than 1, using the BERT and GPT plugins in the T5 model shows significant abnormalities, where the attention output at certain batch indices is entirely zeros.

Additional notes

When the batch size is greater than 1, inference in the T5 family models behaves abnormally.

0xd8b added the bug label Jun 26, 2024
@nv-guomingz (Collaborator)

It seems that you're using a very outdated version (0.7.0); could you please try the latest main branch code?

@0xd8b (Author) commented Jun 26, 2024

@nv-guomingz Yes, due to historical reasons, we developed on version 0.7.0. We have not seen anyone report this issue in the issues section, and it is possible that this problem still exists in the newer versions. Therefore, we hope to address this problem in version 0.7.0. Have you ever encountered a similar issue?

@nv-guomingz (Collaborator) commented Jun 26, 2024

I can't recall such an issue with T5 on version 0.7.0.

Would you please provide step-by-step instructions for reproducing the issue?

I still suggest you try the latest release wheel instead of 0.7.0 to see whether the issue still exists.

If it does, we'll file a bug for internal tracking and investigation.

@0xd8b (Author) commented Jun 26, 2024

@nv-guomingz Just now, we used TensorRT-LLM version 0.9.0 and converted T5-large using the official example (example/enc_dec/).

  1. First, we followed the official example to convert the weights to float16.
  2. Then, we used build.py to build the engine with batch_size=4, using the GPT plugin and the BERT plugin, with float16 precision, keeping everything else consistent with the official example.
  3. Finally, we used run.py for inference. The inference results are abnormal at even batch indices and correct at odd batch indices. We have identified that the self-attention output in the BERT plugin is abnormal. In the GPT plugin, the self-attention output is normal, but the cross-attention output is abnormal, with the values at even batch indices being all zeros. This might be a bug in the plugins.

@hijkzzz (Collaborator) commented Jun 27, 2024

Could you try the latest version, TensorRT-LLM 0.11+?
See the tutorial: https://nvidia.github.io/TensorRT-LLM/installation/linux.html

@1096125073

I have the same issue when using the gpt_attention plugin.
6x A10, TP=2 PP=3.
[screenshot of outputs]

@symphonylyh (Collaborator) commented Jun 28, 2024

Hi @0xd8b @1096125073, can you please provide your TRT-LLM version, runtime type (Python or pybind of C++), model name, TP/PP setup, beam search, and reproducible input examples (English preferred)? On our end we weren't seeing any issue with BS > 1.
Example:
0.10.0, pybind of C++, google/t5-large, TP=1 PP=1, no beam search, ["xxx", "yyy", "zzz"]
And if you can reproduce your issue with TP=1 PP=1, please provide an example under this config -- it's easier to debug.

@0xd8b (Author) commented Jun 28, 2024

Sorry, I did not provide detailed information earlier. Here is the related information:

  1. First, we fine-tuned the t5-large network without changing the decoder's architecture.

  2. We used float32 precision during training and float16 precision during engine conversion.

  3. The engine conversion configurations are as follows:

    tensorrt_llm versions: 0.7.0 and 0.9.0

--world_size=1
--tp_size=1
--pp_size=1
--gpus_per_node=8
--parallel_build=False
--weight_from_pytorch_ckpt=False
--engine_name="t5-small"
--debug_mode=False
--timing_cache="model.cache"
--model_type="t5"
--dtype="float16"
--logits_dtype="float16"
--log_level="info"
--max_batch_size=4
--max_encoder_input_len=1500
--max_decoder_input_len=1
--max_output_len=200
--max_beam_width=1
--use_bert_attention_plugin="float16"
--use_gpt_attention_plugin="float16"
--use_gemm_plugin="float16"
--use_layernorm_plugin=False
--use_rmsnorm_plugin=False
--use_lookup_plugin=False
--enable_qk_half_accum=False
--builder_opt=None
--remove_input_padding=False
--random_seed=None
--use_parallel_embedding=False
--embedding_sharding_dim=0
--use_custom_all_reduce=False
--strongly_typed=True
--gather_all_token_logits=False

  4. We constructed an encoder_output with the shape [1, 545, 1024], using float16, and repeated it four times along the batch dimension, resulting in an encoder_output with the shape [4, 545, 1024] as input for the decoder. The final decoder output predictions for batch=0 and batch=2 were the same, and batch=1 and batch=3 were the same. As mentioned above, the attention layer output is all zeros at even batch indices. Additionally, we modified the C++ code mentioned here in versions 0.7.0 and 0.9.0: Flan t5 xxl result large difference #1343.
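
For clarity, a minimal sketch of the decoder-input construction described in item 4, assuming the [1, 545, 1024] encoder output comes from the HF T5 encoder (variable names are illustrative):

```python
import torch

# Single-sample encoder output from the HF T5 encoder: [1, 545, 1024], float16.
encoder_output = torch.randn(1, 545, 1024, dtype=torch.float16)

# Repeat along the batch dimension to get [4, 545, 1024] as decoder input.
# With identical inputs, all four decoder outputs should match; in our case
# only batch 0/2 and batch 1/3 matched each other.
batched_encoder_output = encoder_output.repeat(4, 1, 1)
```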

@0xd8b (Author) commented Jun 30, 2024

The BERT plugin has a parameter relative_attention_bias: Tensor = None. The relative attention bias can have the shape [num_heads, max_seq_len, max_seq_len], or, for implicit mode, it is the relative attention embedding table with shape [num_heads, num_buckets].

We passed a pre-constructed relative_attention_bias with the shape [num_heads, max_seq_len, max_seq_len], which resulted in the attention calculation outputting all zeros at even batch indices. This issue does not occur in the T5 example in example/enc because it uses the implicit mode.
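
To illustrate the difference between the two modes, here is a minimal sketch with made-up values for num_heads, num_buckets, and max_seq_len (this is not the plugin's internal code, just the two tensor layouts described above):

```python
import torch

num_heads, num_buckets, max_seq_len = 16, 32, 545

# Implicit mode (used by the official T5 example): pass the relative attention
# embedding table of shape [num_heads, num_buckets] and let the plugin derive
# the bias itself.
relative_attention_table = torch.zeros(num_heads, num_buckets, dtype=torch.float16)

# Explicit mode (what we did): pass a pre-computed bias of shape
# [num_heads, max_seq_len, max_seq_len]. This is the path that produced the
# all-zero attention outputs in our runs.
relative_attention_bias = torch.zeros(num_heads, max_seq_len, max_seq_len, dtype=torch.float16)
```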

@symphonylyh (Collaborator)

This makes more sense. Have you tried running the T5 example without the implicit mode? We tested this before in earlier versions. If you can reproduce it with T5 in explicit mode, please let us know.

@1096125073

Yes:
0.9.0, pybind of C++, private model similar to LLaMA, using the gpt_attention plugin, TP=2 PP=3, no beam search.
For example, with the input "how are you?" and batch size 4, the outputs are:
[screenshot of outputs]

But when I tried on four A100 cards (TP=4 PP=1), the result was correct.
This is so weird.

@symphonylyh (Collaborator)

@1096125073 your case is a different issue. Actually, it is EXPECTED. The pybind of C++ runtime doesn't support PP yet, only TP, because we haven't seen many use cases of enc-dec models deployed with PP.

If this is really needed for your case, would you mind opening a new issue and raising it as a feature request?

@1096125073

Hi, thanks a lot, I will open a new issue later.
Actually, I first noticed the inconsistent output with the Triton backend, before using run.py.

@0xd8b (Author) commented Jul 1, 2024

I've found the problem; it's due to the data type of the input data. In the encoder and decoder there is a parameter called encoder_input_lengths. The function documentation did not specify the exact data type required for this parameter, so we did not pay attention to it. We constructed this variable using encoder_input_lengths = torch.sum(attention_mask, dim=-1), but the default output data type of torch.sum is torch.int64, whereas TensorRT requires torch.int32. This data type mismatch is not a problem when batch_size is 1, but it causes the phenomenon I described earlier when batch_size is greater than 1.
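
A minimal sketch of the fix described above; the explicit cast to torch.int32 is the key change (attention_mask here is just a placeholder):

```python
import torch

# Placeholder attention mask: [batch_size, max_encoder_input_len], 1 = real token, 0 = padding.
attention_mask = torch.ones(4, 1500, dtype=torch.int32)

# torch.sum on integer tensors returns int64 by default, which silently breaks
# inference for batch_size > 1; TensorRT expects int32 for encoder_input_lengths.
encoder_input_lengths = torch.sum(attention_mask, dim=-1).to(torch.int32)
```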

@1096125073

Hi, I tried run.py with a py_session and the result was correct, but why is the Triton backend incorrect? Are there any parameters that can be adjusted?

@symphonylyh (Collaborator)

@0xd8b good to hear that you have resolved the issue! Yes, this int64/int32 thing is indeed tricky. That's why we put a caveat here in the past...

Next time it will be more straightforward, now that we know such an interleaving problem is caused by int64/int32.

Closing the issue for now.

@symphonylyh (Collaborator) commented Jul 1, 2024

Hi @1096125073, it's because PP is only supported in the Python runtime at this point. There are several runtime choices for enc-dec:

  1. Python runtime, using examples/enc_dec/run.py or examples/run.py --use_py_session
  2. pybind of C++ runtime, using examples/run.py
  3. Triton backend, which calls the same underlying APIs as (2)

The current status is: (1) supports TP + PP, while (2) and (3) are the same and only support TP, for the reason mentioned above (not many PP use cases for enc-dec, so it's on the roadmap but not at high priority). Does this help?

By the way, do you mean examples/run.py --use_py_session can work for enc-dec? If I remember correctly, enc-dec's Python runtime can only be run via examples/enc_dec/run.py.

Again, feel free to open a new issue regarding enc-dec C++ runtime PP support, and explain your need for PP there so we can prioritize accordingly.
