[BYOC][TENSORRT] Add support for FP16 on TensorRT BYOC flow #10388
Conversation
comaniac left a comment
IIUC, in addition to supporting FP16 for TRT, this PR also attempts to deprecate support for TRT < 7.0.0? Since we don't have a TRT runtime in CI, I have no clue how this affects existing use cases. If so, this would be a more important change and would need to be discussed and documented.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
  VLOG(1) << "Before forcing output tensor type: " << static_cast<int>(out->getType())
          << std::endl;
No need for a newline in the log.
// According to documentation this is required for single FP precision. Always on doesnt seem to
// prevent pure FP32 execution
nit: Better to provide the documentation link.
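For context, the guarded version of that flag would presumably look something like the sketch below; the use_fp16 option name is illustrative, not the PR's actual config attribute, while the nvinfer1 calls are the standard TensorRT IBuilderConfig API.

// Sketch (hypothetical wiring): allow TensorRT's FP16 tactics only when the
// user asked for them. kFP16 permits half-precision kernels but does not force
// an FP32 network into FP16, which matches the comment above.
if (use_fp16) {  // assumed option name, not the PR's exact attr
  config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}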
  node_output_map_[nid].push_back(TensorRTOpInput(out));
  VLOG(1) << "After forcing output tensor type: " << static_cast<int>(out->getType())
          << std::endl;
ditto
// Pass it explicitly
// config_->setFlag(nvinfer1::BuilderFlag::kDEBUG);
Remove?
DLDeviceType src_device) {
  VLOG(1) << "Device type for DLTensorAsWeight: " << dptr->device.device_type;
  VLOG(1) << "DLType for DLTensorAsWeight: " << dptr->dtype;
  VLOG(1) << "DLShape for DLTensorAsWeight: " << dptr->shape << std::endl;
ditto
    mod, params, i_data, mode=mode, use_trt=use_trt
)

print(result_dict)
remove
if run_module:
    assert_result_dict_holds(result_dict)
    print(result_dict)
remove
# run_and_verify_func(
#     get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1), run_module=run_module)
Uncomment?
return f, {"x": x_shape}, []

run_and_verify_func(get_graph(), run_module=run_module)
# for tp in ["float32", "float16", "int8", "uint8"]:
Remove or uncomment?
# run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module)
# run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module)
# run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module)
uncomment
Force-pushed from 81c53f4 to e36ceb0
I reverted all of the versioning changes and kept this focused on the FP16 support. Thanks for the review, PTAL.
comaniac left a comment
LGTM. Thanks. Just nits.
Force-pushed from 6a6640e to 5bdd0ed
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
auto tensor_dtype =
    (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
I'd suggest an ICHECK that fails on unsupported types.
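A minimal sketch of such a check, reusing the TypeMatch/ICHECK helpers already used on these lines (the exact message and placement are illustrative):

// Sketch: reject anything other than FP16/FP32 up front instead of silently
// falling back to kFLOAT for unexpected dtypes.
ICHECK(TypeMatch(dtypes[i], kDLFloat, 16) || TypeMatch(dtypes[i], kDLFloat, 32))
    << "Only FP16 and FP32 inputs are supported.";
auto tensor_dtype =
    (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims);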
# ops with dynamic shapes are offloaded to VM
if check_dynamism(args, op_name):
    return False
if any([x.checked_type.dtype != "float32" for x in args]):
I'm not seeing where the type check (which must now be generalized to float32/float16) has gone. If we remove it altogether then I think we'll either generate bad code or fail at TRT build time, which from the TVM user's point of view is runtime and too late. We also need this check in the predicate to prevent Collage from exploring invalid candidate kernels.
// Get outputs.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
  auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType()
Can you explain this? It seems very specific yet AddLayer is used for all of the supported ops.
This is unfortunately causing a vector index exception for me. I believe we need to pick up the output type from the node's dtype vector.
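A sketch of what that suggestion could look like, assuming the JSONGraphNode is reachable here as `node` and that its GetOpDataType() vector lines up with params.outputs; the plumbing is hypothetical, not the PR's code:

// Sketch of the suggested alternative: take each output's dtype from the
// node's dtype vector instead of comparing the types of inputs 0 and 1,
// which not every op routed through AddLayer necessarily has.
const std::vector<DLDataType>& out_dtypes = node.GetOpDataType();
node_output_map_[nid] = {};
for (size_t i = 0; i < params.outputs.size(); ++i) {
  nvinfer1::ITensor* out = params.outputs[i];
  ICHECK_LT(i, out_dtypes.size()) << "Missing dtype for output " << i;
  out->setType(out_dtypes[i].bits == 16 ? nvinfer1::DataType::kHALF
                                        : nvinfer1::DataType::kFLOAT);
  node_output_map_[nid].push_back(TensorRTOpInput(out));
}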
                      ? nvinfer1::DataType::kFLOAT
                      : nvinfer1::DataType::kINT32;

const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
Another ICHECK would be in order to make sure we're not silently generating bad code.
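For example, a hedged sketch reusing the DLDataType fields already in scope on these lines:

// Sketch: make unsupported weight dtypes a hard failure instead of silently
// mapping them to kFLOAT.
ICHECK(dptr->dtype.bits == 16 || dptr->dtype.bits == 32)
    << "Unsupported weight dtype: " << dptr->dtype;
const auto trt_dtype = (dptr->dtype.bits == 16) ? nvinfer1::DataType::kHALF
                                                : nvinfer1::DataType::kFLOAT;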
builder_->destroy();
for (auto weight : trt_weights_) {
  if (weight.type == nvinfer1::DataType::kFLOAT) {
  if (static_cast<int>(weight.type) <= 1) {
Can we avoid hard-coding the enum constants?
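For instance, a small named helper would keep the intent readable; this is only a sketch, IsFloatWeight is an illustrative name, and the freeing logic at the call site stays as-is:

// Sketch: name the check instead of relying on the numeric values behind
// nvinfer1::DataType (kFLOAT == 0, kHALF == 1).
inline bool IsFloatWeight(const nvinfer1::Weights& w) {
  return w.type == nvinfer1::DataType::kFLOAT || w.type == nvinfer1::DataType::kHALF;
}

// ...and at the call site:
//   if (IsFloatWeight(weight)) { ... }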
Force-pushed from e7405a9 to 2eb104b
Force-pushed from 3a3e1e4 to 0741642
Force-pushed from 0741642 to d0e508b
nvinfer1::Dims dims = VectorToTrtDims(shape);
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
ICHECK((dtypes[i].bits != 16 || dtypes[i].bits != 32))
This is always true; I think you mean bits == 16 || bits == 32.
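That is, presumably the intended check is something like:

// Corrected condition: accept exactly 16- or 32-bit float inputs.
ICHECK(dtypes[i].bits == 16 || dtypes[i].bits == 32)
    << "Only FP16 and FP32 inputs are supported.";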
ret: bool
    True if supported, False if not.
"""
if any([x.checked_type.dtype in supported_types for x in args]):
if all(x.checked_type.dtype in supported_types for x in args):
    return True
# log an error
return False
…0388)

* FP16 support for TRT
* Cleanups on tests
* Fix for typing on output tensor
* Fix icheck
* Add TRT inference builder auto-convert precision flags as attrs in the config
* Address PR comments
* Fix bug on passing the new config attrs to codegen for tensorrt partition

Co-authored-by: Michalis Papapdimitriou <[email protected]>
This PR enables support for FP16 types on the TensorRT BYOC flow.
Changes:
@mbs-octoml @electriclilies @masahi