[NvTensorRT RTX] Add Bfloat16 #24743
### Conversation
I noticed some other things on the path for CUDA device bindings when ORT is compiled without the CUDA EP and with just the CUDA EP interface enabled. I will convert this to a draft and finish up tomorrow.
We will rely on the CUDA EP for device bindings for the time being.
@chilo-ms can you help review this?
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Please also help add bf16 to the Python bindings for the TRT EP.
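As a rough illustration of what the binding needs to support: the BFLOAT16 element type must round-trip through tensor creation. Below is a minimal C++ sketch using the public ORT API; the values and shape are made up for illustration, and the Python side would need the analogous element-type mapping.

```cpp
#include <onnxruntime_cxx_api.h>
#include <array>
#include <cstdint>
#include <vector>

// Sketch: create a tensor with the BFLOAT16 element type through the public
// C++ API. The Python binding needs the equivalent type mapping so that
// bf16 buffers can be passed in and out of sessions.
int main() {
  Ort::MemoryInfo mem_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // bf16 stores the top 16 bits of an IEEE float32; 0x3F80 is 1.0f.
  std::vector<uint16_t> raw(4, 0x3F80);
  std::array<int64_t, 1> shape{4};

  Ort::Value tensor = Ort::Value::CreateTensor(
      mem_info, raw.data(), raw.size() * sizeof(uint16_t),
      shape.data(), shape.size(),
      ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16);
  return 0;
}
```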
@chilo-ms do you mind if I just delete the entire thing and replace it with onnxruntime/onnxruntime/test/perftest/ort_test_session.cc (lines 151 to 183 at 2a09f27)? Is the type checking and explicit warning really necessary at the pybind level?

Agree that the type checking and explicit warning are not necessary. Either way, I think we can do this in another PR?
@chilo-ms the changes are done. Let me know if there is anything else needed for this API.
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
We got the following error when running the tests on T4 machines: C++ exception with description "TensorRT EP failed to create engine from network for fused node: TensorrtExecutionProvider_TRTKernel_graph_mul test_9904508914055400613_0_0" thrown in the test body.

Addressed in this PR.
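For context on why the T4 fails: TensorRT's bf16 kernels require compute capability 8.0 or newer (Ampere), while the T4 is SM 7.5. The sketch below shows the kind of guard that would let tests skip bf16 cases on such hardware; the PR's actual fix may differ.

```cpp
#include <cuda_runtime_api.h>

// Sketch: report whether the GPU has native bfloat16 support.
// BF16 requires compute capability >= 8.0 (Ampere); the T4 is SM 7.5.
bool DeviceSupportsBf16(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  return prop.major >= 8;
}
```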
From the diff under review:

```cpp
} else if constexpr (std::is_same<T, BFloat16>::value) {
  dtype_name = "fp16";
} else if constexpr (std::is_same<T, MLFloat16>::value) {
  dtype_name = "bf16";
```
Was this missed in the original review? It seems we are setting the BFloat16 type name to "fp16" rather than "bf16", and vice versa.
Yes, it seems so. If you want to use bfloat16, I would recommend using a strongly typed ONNX model rather than this global flag.
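For reference, the corrected dispatch simply swaps the two labels back:

```cpp
} else if constexpr (std::is_same<T, BFloat16>::value) {
  dtype_name = "bf16";  // BFloat16 maps to "bf16"
} else if constexpr (std::is_same<T, MLFloat16>::value) {
  dtype_name = "fp16";  // MLFloat16 (IEEE half) maps to "fp16"
```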
### Description

TRT supports BFloat16 and ORT does as well. In addition, `setup.py` was missing a copy for the NV TRT EP, and the TRT EP can only be built against the packaged parser when using TRT RTX.