diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index f1f93c34cbf4b..0ee18cc6799fc 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -85,6 +85,108 @@ struct ShutdownProtobuf {
 
 namespace onnxruntime {
 
+// Helper function to check if a data type is supported by NvTensorRTRTX EP
+static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
+  switch (data_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:         // kFLOAT - 32-bit floating point
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:       // kHALF - IEEE 16-bit floating point
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:      // kBF16 - brain float 16
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:          // kBOOL - 8-bit boolean
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:          // kINT4 - 4-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:          // kINT8 - 8-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:         // kUINT8 - 8-bit unsigned integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:         // kINT32 - 32-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:         // kINT64 - 64-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:  // kFP8 - 8-bit floating point
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Helper function to get data type name as string
+static std::string GetDataTypeName(ONNXTensorElementDataType data_type) {
+  switch (data_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+      return "FLOAT";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+      return "FLOAT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
+      return "BFLOAT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+      return "BOOL";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
+      return "INT4";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+      return "INT8";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+      return "UINT8";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+      return "INT32";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+      return "INT64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
+      return "FLOAT8E4M3FN";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+      return "DOUBLE";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
+      return "STRING";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
+      return "UINT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
+      return "UINT32";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
+      return "UINT64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
+      return "INT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
+      return "COMPLEX64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
+      return "COMPLEX128";
+    default:
+      return "UNKNOWN(" + std::to_string(static_cast<int>(data_type)) + ")";
+  }
+}
+
+// Helper function to check if a node has supported data types
+static bool CheckNodeDataTypes(const Node* node) {
+  // Check input data types
+  for (const auto* input_def : node->InputDefs()) {
+    if (input_def->Exists()) {
+      const auto* type_proto = input_def->TypeAsProto();
+      if (type_proto && type_proto->has_tensor_type()) {
+        auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
+        if (!IsSupportedDataType(data_type)) {
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                                << "' (OpType: " << node->OpType()
+                                << ") has unsupported input data type: " << GetDataTypeName(data_type)
+                                << " for input '" << input_def->Name() << "'";
+          return false;
+        }
+      }
+    }
+  }
+
+  // Check output data types
+  for (const auto* output_def : node->OutputDefs()) {
+    if (output_def->Exists()) {
+      const auto* type_proto = output_def->TypeAsProto();
+      if (type_proto && type_proto->has_tensor_type()) {
+        auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
+        if (!IsSupportedDataType(data_type)) {
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                                << "' (OpType: " << node->OpType()
+                                << ") has unsupported output data type: " << GetDataTypeName(data_type)
+                                << " for output '" << output_def->Name() << "'";
+          return false;
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
   // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
@@ -478,10 +580,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
@@ -562,10 +666,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -624,10 +730,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -1878,6 +1986,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
   /* Iterate all the nodes and exclude the node if:
   * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
   * 2. It's a DDS op.
+   * 3. It has unsupported data types.
   */
  for (const auto& index : nodes_vector) {
    const auto& node = graph.GetNode(node_index[index]);
@@ -1917,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
       supported_node = false;
     }
 
+    // Check data types and print warnings for unsupported types
+    if (supported_node) {
+      if (!CheckNodeDataTypes(node)) {
+        supported_node = false;
+        LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                           << "' (OpType: " << node->OpType()
+                           << ") excluded due to unsupported data types";
+      }
+    }
+
     if (supported_node) {
       if (new_subgraph) {
         parser_nodes_vector.emplace_back();
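
A minimal standalone sketch (not part of the diff) of how the new type gate behaves. It assumes the public onnxruntime_c_api.h header is on the include path and that its enum defines the INT4 and FLOAT8E4M3FN enumerators (present in recent ORT releases); IsSupportedDataType below is a local copy of the helper added above, not the EP-internal symbol. Note that nodes failing the check are simply excluded from the TRT subgraph during GetCapability, so they fall back to other execution providers rather than failing the session.

#include <cstdio>
#include <onnxruntime_c_api.h>

// Local copy of the predicate from the diff, for illustration only.
static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
  switch (data_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
      return true;
    default:
      return false;
  }
}

int main() {
  // DOUBLE is absent from the supported list, so a node producing or
  // consuming DOUBLE tensors would be excluded from the TRT subgraph.
  std::printf("DOUBLE:  %s\n", IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) ? "supported" : "unsupported");
  std::printf("FLOAT16: %s\n", IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) ? "supported" : "unsupported");
  std::printf("INT4:    %s\n", IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4) ? "supported" : "unsupported");
  return 0;
}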