diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index f0c46279c566a..4d183b95bd938 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -286,11 +286,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, // Check if we need to add a cast node for int64 bool needs_int64_cast = false; if (is_graph_output) { - for (const auto& input_name : input_names) { - if (input_name.find("_cast_int32") != std::string::npos) { - needs_int64_cast = true; - break; - } + if (supported_qnn_data_type == output_info.qnn_data_type && + (output_info.qnn_data_type == QNN_DATATYPE_INT_64 || output_info.qnn_data_type == QNN_DATATYPE_UINT_64)) { + supported_qnn_data_type = supported_qnn_data_type == QNN_DATATYPE_INT_64 ? QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32; + needs_int64_cast = true; } }