@@ -1466,12 +1466,15 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
         fused_inputs.erase(it);
         erased.insert(output);
       }
-      // Only when output is neither in input list nor erased list, add the output to output list
+      // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
       else if (erased.find(output) == erased.end()) {
         if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
           graph_outputs_to_add[output] = output_order;
         }
-        fused_outputs[output] = output_order++;
+
+        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+          fused_outputs[output] = output_order++;
+        }
       }
     }
   }
@@ -2114,12 +2114,15 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
         fused_inputs.erase(it);
         erased.insert(output);
       }
-      // Only when output is neither in input list nor erased list, add the output to output list
+      // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
       else if (erased.find(output) == erased.end()) {
         if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
           graph_outputs_to_add[output] = output_order;
         }
-        fused_outputs[output] = output_order++;
+
+        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+          fused_outputs[output] = output_order++;
+        }
       }
     }
   }
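Both hunks apply the same rule: a node output only becomes an output of the fused subgraph if some other node actually consumes it. The following is a minimal standalone C++ sketch of that guard; NodeOutput and CollectFusedOutputs are illustrative stand-ins, not the ONNX Runtime API, and the graph-output special case handled by graph_outputs_to_add above is omitted for brevity.

// Standalone sketch of the consumer-count guard, using simplified
// stand-in types. NodeOutput and CollectFusedOutputs are illustrative
// names, not part of the ONNX Runtime API.
#include <cstddef>
#include <map>
#include <string>
#include <vector>

struct NodeOutput {
  std::string name;
  std::size_t consumer_count;  // number of downstream nodes reading this output
};

// Mirrors the GetConsumerNodes(...).size() > 0 check in the hunks above:
// an output with no consumers (a dangling optional output) is never
// registered as an output of the fused subgraph.
std::map<std::string, int> CollectFusedOutputs(const std::vector<NodeOutput>& outputs) {
  std::map<std::string, int> fused_outputs;
  int output_order = 0;
  for (const auto& output : outputs) {
    if (output.consumer_count > 0) {
      fused_outputs[output.name] = output_order++;
    }
  }
  return fused_outputs;
}

int main() {
  // Dropout produces two outputs; only the data output feeds MatMul.
  std::vector<NodeOutput> dropout_outputs{{"dropout_out", 1}, {"dropout_mask", 0}};
  auto fused = CollectFusedOutputs(dropout_outputs);
  // fused contains only "dropout_out"; the unused "dropout_mask" never
  // becomes a fused-subgraph output the EP would have to materialize.
  return fused.count("dropout_mask") == 0 ? 0 : 1;
}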
36 changes: 36 additions & 0 deletions onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc
@@ -113,6 +113,42 @@ void SmallModelTest(CompileParam test_param, bool fully_supported_model) {
   session_object.Run(run_options, io_binding);
 }
 
+TEST(CompileApiTest, ModelWithOptionalNodeOutput) {
+  PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_simple_model_with_optional_node_output.onnx");
+  PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_simple_model_with_optional_node_output_ctx.onnx");
+  CreateSimpleModelWithOptionalNodeOutput(model_name);
+
+  Ort::SessionOptions session_options;
+  std::unordered_map<std::string, std::string> option_map{
+      {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, std::to_string(false)}};
+  auto ep = AppendTrtEtxEP(session_options, option_map);
+  Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options);
+
+  model_compile_options.SetInputModelPath(model_name.c_str());
+  model_compile_options.SetOutputModelPath(model_name_ctx.c_str());
+
+  ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK());
+
+  // Load the model from file
+  onnx::ModelProto model;
+  std::ifstream ifs(model_name_ctx, std::ios::binary);
+  if (!ifs) {
+    std::cerr << "Failed to open " << model_name_ctx.c_str() << "\n";
+    ASSERT_TRUE(false);
+  }
+
+  if (!model.ParseFromIstream(&ifs)) {
+    std::cerr << "Failed to parse ONNX model\n";
+    ASSERT_TRUE(false);
+  }
+
+  const onnx::GraphProto& graph = model.graph();
+  ASSERT_TRUE(graph.node_size() == 1);
+
+  const onnx::NodeProto& node = graph.node(0);
+  ASSERT_TRUE(node.output_size() == 1);
+}
+
 TEST_P(CompileApiTest, SmallModel) {
   const auto& test_param = GetCompileParam();
   SmallModelTest(test_param, true);
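The test asserts that compilation collapses the Dropout + MatMul graph into a single node with a single output, i.e. the unused optional mask output does not survive into the compiled model. If one wanted to tighten the test further, the remaining node could also be checked against ONNX Runtime's EP-context convention; the fragment below is a sketch that would sit after the existing output_size assertion, assuming the usual "EPContext" op type in the "com.microsoft" domain.

// Possible additional assertions, assuming the compiled node follows
// ONNX Runtime's EPContext convention ("EPContext" op, "com.microsoft" domain).
ASSERT_TRUE(node.op_type() == "EPContext");
ASSERT_TRUE(node.domain() == "com.microsoft");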
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // Licensed under the MIT License.
 
@@ -138,6 +138,72 @@ void CreateBaseModel(const PathString& model_name,
   ASSERT_TRUE(status.IsOK());
 }
 
+void CreateSimpleModelWithOptionalNodeOutput(const PathString& model_name) {
+  // Create a new model
+  Model model("DropoutMatMulModel", false, DefaultLoggingManager().DefaultLogger());
+  Graph& graph = model.MainGraph();
+
+  // Define inputs
+  // X: [3, 2]
+  ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
+  ONNX_NAMESPACE::TypeProto float_tensor;
+  std::vector<int> dims = {3, 2};
+  float_tensor.mutable_tensor_type()->set_elem_type(dtype);
+  for (auto dim : dims) {
+    float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+  }
+
+  auto& X = graph.GetOrCreateNodeArg("X", &float_tensor);
+
+  // W: [2, 3]
+  ONNX_NAMESPACE::TypeProto float_tensor_2;
+  std::vector<int> dims_2 = {2, 3};
+  float_tensor_2.mutable_tensor_type()->set_elem_type(dtype);
+  for (auto dim : dims_2) {
+    float_tensor_2.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+  }
+  auto& W = graph.GetOrCreateNodeArg("W", &float_tensor_2);
+
+  // Define outputs
+  ONNX_NAMESPACE::TypeProto float_tensor_3;
+  std::vector<int> dims_3 = {2, 3};
+  float_tensor_3.mutable_tensor_type()->set_elem_type(dtype);
+  for (auto dim : dims_3) {
+    float_tensor_3.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+  }
+  auto& Y = graph.GetOrCreateNodeArg("Y", &float_tensor_3);
+
+  // Dropout Node
+  auto& dropout_out = graph.GetOrCreateNodeArg("dropout_out", &float_tensor);
+
+  dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL;
+  ONNX_NAMESPACE::TypeProto boolean_tensor;
+  boolean_tensor.mutable_tensor_type()->set_elem_type(dtype);
+  auto& dropout_optional_out = graph.GetOrCreateNodeArg("dropout_mask", &boolean_tensor);
+
+  Node& dropout_node = graph.AddNode("DropoutNode", "Dropout", "Applies dropout",
+                                     {&X}, {&dropout_out, &dropout_optional_out});
+
+  // MatMul Node
+  Node& matmul_node = graph.AddNode("MatMulNode", "MatMul", "Dropout followed by MatMul",
+                                    {&dropout_out, &W}, {&Y});
+
+  // Mark graph inputs/outputs
+  graph.SetInputs({&X, &W});
+  graph.SetOutputs({&Y});
+
+  // Resolve to finalize
+  auto status = graph.Resolve();
+  if (!status.IsOK()) {
+    std::cerr << "Graph resolve failed: " << status.ErrorMessage() << "\n";
+  }
+
+  // Serialize to ONNX file
+  status = Model::Save(model, model_name);
+
+  ASSERT_TRUE(status.IsOK());
+}
+
 // Helper to create large initializers
 ONNX_NAMESPACE::TensorProto CreateLargeWeight(
     const std::string& name,
@@ -1,4 +1,4 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // Licensed under the MIT License.
 #pragma once
@@ -119,6 +119,22 @@ void CreateBaseModel(const PathString& model_name,
 
 void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path);
 
+/**
+ * Create a simple model that has a dropout node that has an optional output.
+ * \param model_name - model name
+ *
+ *
+ *   input (X, W)
+ *        │
+ *   Dropout(X) - has an optional "mask" output
+ *        │
+ *   MatMul(Dropout_out, W)
+ *        │
+ *   output (Y)
+ *
+ */
+void CreateSimpleModelWithOptionalNodeOutput(const PathString& model_name);
+
 Ort::IoBinding generate_io_binding(
     Ort::Session& session,
     std::map<std::string, std::vector<int64_t>> shape_overwrites = {},
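For reference, a minimal sketch of how the helper declared above is meant to be called; the file name here is illustrative, while path_utils::MakePathString mirrors its use in the new test.

// Call pattern for the helper; "optional_output_model.onnx" is an
// illustrative file name, not one used by the PR.
PathString model_path = path_utils::MakePathString("optional_output_model.onnx");
CreateSimpleModelWithOptionalNodeOutput(model_path);
// The saved graph is Dropout -> MatMul, with Dropout's optional "mask"
// output declared but never consumed by any downstream node.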