enc_dec model results are not aligned with the HF model #612

Closed

NSun-S opened this issue Dec 8, 2023 · 7 comments
@NSun-S

NSun-S commented Dec 8, 2023

I tried to use TRT-LLM v0.6.1 to optimize an enc-dec model (Flan-T5-small). During usage, I observed that its output is not aligned with the Hugging Face model under FP32 precision (nor under BF16). The following are the steps to reproduce this behavior:

1. git clone https://huggingface.co/google/flan-t5-small tmp/hf_models/flan-t5-small
2. python t5/hf_convert.py -i tmp/hf_models/flan-t5-small -o tmp/trt_models/flan-t5-small \
        --weight_data_type float32 --inference_tensor_para_size 1
3. python build.py --weight_dir tmp/trt_models/flan-t5-small/tp1/ --output_dir tmp/trt_engines/flan-t5-small \
        --engine_name flan-t5-small --dtype float32 --max_encoder_input_len 512 --max_beam_width 3 \
        --use_bert_attention_plugin --use_gpt_attention_plugin --use_gemm_plugin --use_rmsnorm_plugin --use_lookup_plugin
4. python run.py --max_new_tokens 50 --engine_dir tmp/trt_engines/flan-t5-small/float32/tp1/ \
        --model_name google/flan-t5-small --compare_hf_fp32 --engine_name flan-t5-small --num_beams 1

The prompt I used is the one you provided: "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."

HF output text: ['Das Haus ist wunderbar, es es es es es es es es es es es es es es es es']
TRT-LLM output text: ['Das Haus ist wunderbar, es es voller zeitgeisten Charm und bietet eine gemütliche, energisches, energisches Interior mit schönen Detailen und eine sa']

I observed some slight anomalies in the encoder output, as shown below:

--------------------------------------
Debug output for Encoder
--------------------------------------
Registered output tensors are:  dict_keys(['encoder_output'])
encoder_output: mean=0.084, sum=1382.613, max=0.475
tensor([-0.1069, -0.0448,  0.3182,  0.0441,  0.1847,  0.1091, -0.0059,  0.0022,
         0.1441,  0.0039], device='cuda:0')
Tensor Shape:  torch.Size([1, 32, 512])

--------------------------------------
hf_model encoder_output: mean=0.084, sum=1380.891, max=0.478
tensor([-0.1191, -0.0537,  0.3337,  0.0399,  0.1866,  0.1062, -0.0063, -0.0088,
         0.1511, -0.0212], device='cuda:0', grad_fn=<SliceBackward0>)
Tensor Shape:  torch.Size([1, 32, 512])

--------------------------------------

Although the HF model's output is itself poor, in my view the TRT-LLM output should be perfectly aligned with it. The difference in output becomes more pronounced in scenarios such as beam search and batched input.
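
For reference, a minimal sketch of how such a comparison can be quantified beyond mean/sum/max, assuming both encoder outputs are available as same-shape torch tensors (the tensor names here are illustrative, not from the scripts above):

```python
import torch

def compare_outputs(trt_out: torch.Tensor, hf_out: torch.Tensor) -> None:
    """Print simple mismatch statistics between two same-shape tensors."""
    hf_out = hf_out.detach().to(trt_out.device, trt_out.dtype)
    diff = (trt_out - hf_out).abs()
    rel = diff / hf_out.abs().clamp_min(1e-6)
    print(f"max abs diff : {diff.max().item():.6f}")
    print(f"mean abs diff: {diff.mean().item():.6f}")
    print(f"max rel diff : {rel.max().item():.6f}")
    print(f"allclose(atol=1e-3): {torch.allclose(trt_out, hf_out, atol=1e-3)}")

# e.g. compare_outputs(trt_encoder_output, hf_encoder_output)
```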

Is this normal, or is there a known precision problem? I have read the related issues, but none of them solved my problem. Could you give me some suggestions and possible solutions?

@symphonylyh
Collaborator

symphonylyh commented Dec 8, 2023

Hi @NSun-S , thanks for posting this.

Conclusion first: this is normal, and it's likely an HF/PyTorch gemm problem.

I was bothered by the same tiny numerical difference during my development of enc-dec too. You're checking the encoder_output tensor, which has already gone through some numerical accumulation; I was checking the Q, K, V tensors right after the QKV projection, and even there the tiny deviation reaches a noticeable decimal difference after a few layers. I wanted to know what the ground truth is, so this is what I did: for FP32, with Q = W*X, I saved one row of W and one column of X as tensors (i.e., two vectors whose multiply-add produces one element of Q), then computed that element four ways:
(1) torch.matmul
(2) TRT-LLM without the gemm plugin
(3) TRT-LLM with the gemm plugin
(4) the golden standard -- hand calculation, for which I used numpy
Findings: (4) == (3) ~= (2) != (1). In other words, HF/PyTorch is not 100% reliable, so we shouldn't treat HF/PyTorch as the ground truth when it comes to tiny numerical differences (a minimal sketch of this kind of check is included at the end of this comment). As a side note, even FP32 is not guaranteed to match perfectly, because each framework selects its gemm algorithms differently.
A tiny difference will propagate over layers and over sequence length, so a results gap is GUARANTEED, and it is GUARANTEED to grow as the sequence gets longer. Model accuracy should be evaluated on downstream tasks instead of purely numerically. For the example above, one could even argue that the TRT-LLM result quality is better than the HF result -- though that doesn't prove anything either, since LLM generation is not a deterministic process.
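
For anyone who wants to repeat this kind of check, here is a minimal sketch (illustrative only; random placeholder vectors stand in for the dumped W row and X column) comparing float32 dot products against a float64 reference:

```python
import numpy as np
import torch

# Placeholder data standing in for one row of W and one column of X
# dumped from the FP32 Q = W*X projection during debugging.
rng = np.random.default_rng(0)
w_row = rng.standard_normal(512).astype(np.float32)
x_col = rng.standard_normal(512).astype(np.float32)

# "Golden" reference: accumulate the dot product in float64.
golden = np.dot(w_row.astype(np.float64), x_col.astype(np.float64))

# Plain float32 numpy accumulation.
np_fp32 = np.dot(w_row, x_col)

# torch.matmul on the same float32 data (CPU here; cuBLAS on GPU may differ again).
torch_fp32 = torch.matmul(torch.from_numpy(w_row), torch.from_numpy(x_col)).item()

print(f"float64 reference: {golden:.8f}")
print(f"numpy float32    : {np_fp32:.8f}  diff={np_fp32 - golden:.2e}")
print(f"torch float32    : {torch_fp32:.8f}  diff={torch_fp32 - golden:.2e}")
```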

@symphonylyh symphonylyh self-assigned this Dec 8, 2023
@NSun-S
Author

NSun-S commented Dec 8, 2023

@symphonylyh Thank you very much! Your reply answered the questions that had been troubling me.

@symphonylyh
Collaborator

symphonylyh commented Dec 11, 2023

A little more explanation on this numerical analysis:
Such an effect can be more prominent for encoder-decoder models than for decoder-only models, and the reason is cross attention: an encoder-decoder model first runs through the encoder once to get the encoder output and the encoder KV cache; the decoder's cross attention then does a matmul between the decoder input and that encoder KV cache.
Given the accumulation explained above, the encoder output and KV cache have already picked up some numerical error across all the encoder layers. The cross attention calculation inherits that error and accumulates it further through all the decoder layers -- which is why encoder-decoder models are more susceptible to this kind of numerical deviation.
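
To make the accumulation point concrete, here is a toy sketch (plain PyTorch, not TRT-LLM code) that runs the same randomly initialized layer stack in float32 and float64 and reports the gap at each depth; per-layer rounding error compounds with depth in the same way it compounds across the encoder layers and then again across the decoder layers:

```python
import copy
import torch

torch.manual_seed(0)
d_model, n_layers = 512, 12

# The same weights, kept once in float32 and once in float64 as a numerical reference.
layers32 = [torch.nn.Linear(d_model, d_model) for _ in range(n_layers)]
layers64 = [copy.deepcopy(layer).double() for layer in layers32]

x = torch.randn(1, d_model)
h32, h64 = x.clone(), x.double()
for i, (l32, l64) in enumerate(zip(layers32, layers64), start=1):
    h32 = torch.tanh(l32(h32))
    h64 = torch.tanh(l64(h64))
    rel = (h32.double() - h64).abs().max() / h64.abs().max()
    print(f"layer {i:2d}: max relative deviation of fp32 vs fp64 = {rel.item():.2e}")
```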

The key takeaway is: it is better to evaluate on real downstream tasks and see whether, and by how much, such numerical differences affect output quality, rather than to pursue an exact match of logit values. Of course, it is sometimes hard to tell whether a gap comes from an implementation bug or from numerical deviation, but so far, based on our analysis and user feedback, we do not believe it is an implementation bug in TRT-LLM's encoder-decoder models.
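
As a rough illustration of that takeaway (the strings below are placeholders, and the sacrebleu package is just one possible choice of metric), a downstream comparison could look like this instead of a logit comparison:

```python
import sacrebleu

# Hypothetical reference translation plus the generations from both backends.
references = ["Das Haus ist wunderbar, strahlt zeitlosen Charme aus und bietet ein warmes Interieur."]
hf_outputs = ["Das Haus ist wunderbar, es es es es es"]
trt_outputs = ["Das Haus ist wunderbar, es es voller zeitgeisten Charm und bietet ein Interior"]

print("HF BLEU     :", sacrebleu.corpus_bleu(hf_outputs, [references]).score)
print("TRT-LLM BLEU:", sacrebleu.corpus_bleu(trt_outputs, [references]).score)
```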

@symphonylyh
Collaborator

symphonylyh commented Dec 13, 2023

@NSun-S , for Flan-T5 specifically, I found a fix that should be very relevant to your observation above: see #474 (comment). Please apply the fix locally and see if it solves your problem.

The above discussion regarding numerical difference still holds true. But I guess this fix may solve the issue you described.

@NSun-S
Author

NSun-S commented Dec 18, 2023

> @NSun-S , for Flan-T5 specifically, I found a fix that should be very relevant to your observation above: see #474 (comment). Please apply the fix locally and see if it solves your problem.
>
> The above discussion regarding numerical difference still holds true. But I guess this fix may solve the issue you described.

@symphonylyh Thanks for your support! We have tried the modification mentioned above, and the output does change somewhat. However, the difference caused by this numerical issue still exists and has not completely disappeared.

I am willing to try any suggestions, and I will share any new findings in this issue as well.

@symphonylyh
Collaborator

@NSun-S just to confirm: have you commented out the rescale_before_lm_head logic and rebuilt the engine? I just found that my comments there may link to the wrong lines after the latest main branch was updated. In any case, you can stick to the current main, which includes this fix, so no manual change is needed.
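
For context, the logic in question mirrors the conditional rescale that HF's T5 applies before the LM head: original T5 ties the input/output embeddings and rescales hidden states by d_model**-0.5, whereas Flan-T5 checkpoints set tie_word_embeddings=False and skip the rescale, so the engine must match the variant it was built for. A small illustrative sketch (hypothetical helper name, Flan-T5-small-like sizes; not TRT-LLM code):

```python
import torch

def project_to_vocab(hidden: torch.Tensor,
                     lm_head: torch.nn.Linear,
                     d_model: int,
                     tie_word_embeddings: bool) -> torch.Tensor:
    """Mimic the pre-LM-head rescale: applied only when embeddings are tied."""
    if tie_word_embeddings:
        hidden = hidden * (d_model ** -0.5)   # original T5 behavior
    return lm_head(hidden)                    # Flan-T5: no rescale

# Flan-T5-small-like dimensions (illustrative).
d_model, vocab_size = 512, 32128
lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
hidden = torch.randn(1, 1, d_model)

logits_flan_t5 = project_to_vocab(hidden, lm_head, d_model, tie_word_embeddings=False)
logits_t5_v1 = project_to_vocab(hidden, lm_head, d_model, tie_word_embeddings=True)
```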

@NSun-S
Author

NSun-S commented Dec 19, 2023

@symphonylyh Thank you for the reminder. I confirm that I have made the corresponding modifications and rebuilt the engine. I also noticed the difference between Flan-T5 and T5 while locating the issue. At least for some examples, this modification does not fundamentally solve the problem (as mentioned earlier, differences start in the encoder and continue to accumulate). In any case, the original Flan-T5 does not use rescale_before_lm_head, so we should handle it the same way.
